Commit 9c268d4 (parent: 76a6033)
Author: Tal

Add ONNX Runtime quantization support

Changed files:
- export_onnx.py (+19 -2)
- requirements.txt (+1 -1)
export_onnx.py CHANGED
@@ -1,6 +1,10 @@
 import copy
 import argparse
 from optimum.exporters.onnx import onnx_export_from_model
+# quantization import
+from optimum.onnxruntime import ORTQuantizer
+from optimum.onnxruntime.configuration import AutoQuantizationConfig
+
 from collections import OrderedDict
 from typing import Dict
 from optimum.exporters.onnx.model_configs import XLMRobertaOnnxConfig
@@ -27,7 +31,7 @@ class BGEM3OnnxConfig(XLMRobertaOnnxConfig):
 )
 
 
-def main(output: str, opset: int, device: str, optimize: str, atol: str):
+def main(output: str, opset: int, device: str, optimize: str, atol: str, quantize: bool = False):
     model = BGEM3InferenceModel()
     bgem3_onnx_config = BGEM3OnnxConfig(model.config)
     onnx_export_from_model(
@@ -40,6 +44,12 @@ def main(output: str, opset: int, device: str, optimize: str, atol: str):
         atol=atol,
         device=device,
     )
+    if quantize:
+        quantizer = ORTQuantizer.from_pretrained(output, file_name="model.onnx")
+        qconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=True)
+        quantizer.quantize(save_dir=output, quantization_config=qconfig)
+
+
 
 
 if __name__ == "__main__":
@@ -80,6 +90,13 @@ if __name__ == "__main__":
         default=None,
         help="If specified, the absolute difference tolerance when validating the model. Otherwise, the default atol for the model will be used.",
     )
+
+    parser.add_argument(
+        "--quantize",
+        type=bool,
+        default=False,
+        help="If specified, the model will be quantized using ONNX Runtime quantization",
+    )
     args = parser.parse_args()
 
-    main(args.output, args.opset, args.device, args.optimize, args.atol)
+    main(args.output, args.opset, args.device, args.optimize, args.atol, args.quantize)
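The new quantize path applies ONNX Runtime's dynamic int8 quantization to the exported graph: AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=True) quantizes weights per channel ahead of time and activations on the fly, targeting CPUs with AVX-512 VNNI, and quantizer.quantize writes the result next to the original model (optimum appends a _quantized suffix by default). A minimal sketch of loading the quantized export, assuming the output directory was onnx_out and that BAAI/bge-m3 is the right tokenizer checkpoint for the XLM-RoBERTa-based model above:

import onnxruntime as ort
from transformers import AutoTokenizer

# Assumptions: export directory "onnx_out"; optimum's default quantized
# file name "model_quantized.onnx"; tokenizer checkpoint "BAAI/bge-m3".
tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3")
session = ort.InferenceSession("onnx_out/model_quantized.onnx")

# XLM-RoBERTa tokenizers produce input_ids and attention_mask, which is
# what the exported graph expects as inputs.
encoded = tokenizer("ONNX Runtime quantization", return_tensors="np")
outputs = session.run(None, dict(encoded))
print([o.shape for o in outputs])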
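One caveat with the new flag: argparse's type=bool does not parse booleans. bool("False") is True, so any explicit value on the command line (e.g. --quantize False) still enables quantization; only omitting the flag keeps it off. A store_true flag is the usual fix; a sketch of that alternative (a suggestion, not what this commit does):

import argparse

parser = argparse.ArgumentParser()
# store_true makes --quantize a presence flag: absent -> False,
# present -> True, with no string-to-bool parsing involved.
parser.add_argument(
    "--quantize",
    action="store_true",
    help="If specified, the model will be quantized using ONNX Runtime quantization",
)
assert parser.parse_args([]).quantize is False
assert parser.parse_args(["--quantize"]).quantize is True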
requirements.txt CHANGED
@@ -1,7 +1,7 @@
 accelerate==0.27.2
 huggingface-hub==0.20.3
 onnx==1.15.0
-onnxruntime==1.
+onnxruntime==1.16.0
 optimum==1.17.0
 torch==2.2.0
 transformers==4.37.2
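The onnxruntime pin moves to 1.16.0 alongside the quantization support. A trivial sanity check that the installed wheel matches the pin (the version string comes from the requirements above):

import onnxruntime

# Should print 1.16.0 when the environment matches requirements.txt.
print(onnxruntime.__version__)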