Tal committed
Commit 9c268d4
1 Parent(s): 76a6033

Add ONNX Runtime quantization support

Files changed (2):
  1. export_onnx.py +19 -2
  2. requirements.txt +1 -1
export_onnx.py CHANGED
@@ -1,6 +1,10 @@
 import copy
 import argparse
 from optimum.exporters.onnx import onnx_export_from_model
+# quantization imports
+from optimum.onnxruntime import ORTQuantizer
+from optimum.onnxruntime.configuration import AutoQuantizationConfig
+
 from collections import OrderedDict
 from typing import Dict
 from optimum.exporters.onnx.model_configs import XLMRobertaOnnxConfig
@@ -27,7 +31,7 @@ class BGEM3OnnxConfig(XLMRobertaOnnxConfig):
     )
 
 
-def main(output: str, opset: int, device: str, optimize: str, atol: str):
+def main(output: str, opset: int, device: str, optimize: str, atol: str, quantize: bool = False):
     model = BGEM3InferenceModel()
     bgem3_onnx_config = BGEM3OnnxConfig(model.config)
     onnx_export_from_model(
@@ -40,6 +44,12 @@ def main(output: str, opset: int, device: str, optimize: str, atol: str):
         atol=atol,
         device=device,
     )
+    if quantize:
+        quantizer = ORTQuantizer.from_pretrained(output, file_name="model.onnx")
+        qconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=True)
+        quantizer.quantize(save_dir=output, quantization_config=qconfig)
+
+
 
 
 if __name__ == "__main__":
@@ -80,6 +90,13 @@ if __name__ == "__main__":
         default=None,
         help="If specified, the absolute difference tolerance when validating the model. Otherwise, the default atol for the model will be used.",
     )
+
+    parser.add_argument(
+        "--quantize",
+        action="store_true",
+        default=False,
+        help="If specified, the model will be quantized using ONNX Runtime quantization.",
+    )
     args = parser.parse_args()
 
-    main(args.output, args.opset, args.device, args.optimize, args.atol)
+    main(args.output, args.opset, args.device, args.optimize, args.atol, args.quantize)
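With --quantize set, the script exports the FP32 graph and then applies ONNX Runtime dynamic int8 quantization: AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=True) targets AVX-512 VNNI CPUs and, being dynamic, needs no calibration dataset. As a minimal sketch of consuming the result (not part of the commit; it assumes ORTQuantizer's default "quantized" file suffix, an export directory of onnx/, and the BAAI/bge-m3 tokenizer):

# Smoke-test sketch for the quantized export; the paths and the tokenizer
# repo below are assumptions, not part of this commit.
import onnxruntime as ort
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3")  # assumed tokenizer repo
session = ort.InferenceSession(
    "onnx/model_quantized.onnx", providers=["CPUExecutionProvider"]
)

# Feed only the inputs the graph declares (input_ids and attention_mask
# for an XLM-RoBERTa-style config).
encoded = tokenizer("a short smoke-test sentence", return_tensors="np")
graph_inputs = {i.name for i in session.get_inputs()}
outputs = session.run(None, {k: v for k, v in encoded.items() if k in graph_inputs})
print([o.shape for o in outputs])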
requirements.txt CHANGED
@@ -1,7 +1,7 @@
 accelerate==0.27.2
 huggingface-hub==0.20.3
 onnx==1.15.0
-onnxruntime==1.17.0
+onnxruntime==1.16.0
 optimum==1.17.0
 torch==2.2.0
 transformers==4.37.2
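Since the commit pins onnxruntime down to 1.16.0, a quick environment check before exporting can save a confusing failure later; a sketch using only the version strings pinned above:

# Verify the installed packages match the pins in requirements.txt (a sketch).
import onnx
import onnxruntime

assert onnx.__version__ == "1.15.0", onnx.__version__
assert onnxruntime.__version__ == "1.16.0", onnxruntime.__version__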