echarlaix HF staff committed on
Commit
da32672
·
1 Parent(s): e23b1fe

remove quant method

Browse files
Files changed (1) hide show
  1. app.py +15 -8
app.py CHANGED
@@ -35,7 +35,6 @@ from optimum.intel import (
35
  def process_model(
36
  model_id: str,
37
  dtype: str,
38
- quant_method: str,
39
  calibration_dataset: str,
40
  ratio: str,
41
  private_repo: bool,
@@ -49,9 +48,6 @@ def process_model(
49
  username = whoami(oauth_token.token)["name"]
50
  new_repo_id = f"{username}/{model_name}-openvino-{dtype}"
51
 
52
- if quant_method != "default":
53
- new_repo_id += f"-{quant_method}"
54
-
55
  task = TasksManager.infer_task_from_model(model_id)
56
  if task not in _HEAD_TO_AUTOMODELS:
57
  raise ValueError(
@@ -68,11 +64,21 @@ def process_model(
68
  use_auth_token=oauth_token.token,
69
  )
70
  export = len(ov_files) == 0
 
 
 
 
 
 
 
 
 
 
71
  quantization_config = OVWeightQuantizationConfig(
72
- bits=8 if dtype == "int8" else 4,
73
  quant_method=quant_method,
74
- dataset=calibration_dataset,
75
- ratio=1.0 if dtype == "int8" else ratio,
76
  )
77
 
78
  api = HfApi(token=oauth_token.token)
@@ -166,6 +172,7 @@ dtype = gr.Dropdown(
166
  filterable=False,
167
  visible=True,
168
  )
 
169
  quant_method = gr.Dropdown(
170
  ["default", "awq", "hybrid"],
171
  value="default",
@@ -173,6 +180,7 @@ quant_method = gr.Dropdown(
173
  filterable=False,
174
  visible=True,
175
  )
 
176
  calibration_dataset = gr.Dropdown(
177
  [
178
  "wikitext2",
@@ -210,7 +218,6 @@ interface = gr.Interface(
210
  inputs=[
211
  model_id,
212
  dtype,
213
- quant_method,
214
  calibration_dataset,
215
  ratio,
216
  private_repo,
 
35
  def process_model(
36
  model_id: str,
37
  dtype: str,
 
38
  calibration_dataset: str,
39
  ratio: str,
40
  private_repo: bool,
 
48
  username = whoami(oauth_token.token)["name"]
49
  new_repo_id = f"{username}/{model_name}-openvino-{dtype}"
50
 
 
 
 
51
  task = TasksManager.infer_task_from_model(model_id)
52
  if task not in _HEAD_TO_AUTOMODELS:
53
  raise ValueError(
 
64
  use_auth_token=oauth_token.token,
65
  )
66
  export = len(ov_files) == 0
67
+
68
+ is_int8 = dtype == "int8"
69
+ library_name = TasksManager.infer_library_from_model(model_id)
70
+ if library_name == "diffusers":
71
+ quant_method = "hybrid"
72
+ elif not is_int8:
73
+ quant_method = "awq"
74
+ else:
75
+ quant_method = "default"
76
+
77
  quantization_config = OVWeightQuantizationConfig(
78
+ bits=8 if is_int8 else 4,
79
  quant_method=quant_method,
80
+ dataset=None if quant_method=="default" else calibration_dataset,
81
+ ratio=1.0 if is_int8 else ratio,
82
  )
83
 
84
  api = HfApi(token=oauth_token.token)
 
172
  filterable=False,
173
  visible=True,
174
  )
175
+ """
176
  quant_method = gr.Dropdown(
177
  ["default", "awq", "hybrid"],
178
  value="default",
 
180
  filterable=False,
181
  visible=True,
182
  )
183
+ """
184
  calibration_dataset = gr.Dropdown(
185
  [
186
  "wikitext2",
 
218
  inputs=[
219
  model_id,
220
  dtype,
 
221
  calibration_dataset,
222
  ratio,
223
  private_repo,