wenjiao commited on
Commit
111e33c
1 Parent(s): 046adc3

update ComputeDtype logic

Browse files
Files changed (1) hide show
  1. app.py +6 -3
app.py CHANGED
@@ -116,12 +116,15 @@ def update_table(
116
  group_dtype: str
117
  ):
118
 
119
- compute_dtype = [compute_dtype]
120
-
121
  if weight_dtype == 'All':
122
  weight_dtype = ['int2', 'int3', 'int4', 'nf4', 'fp4']
123
  else:
124
  weight_dtype = [weight_dtype]
 
 
 
 
 
125
 
126
  if group_dtype == 'All':
127
  group_dtype = [-1, 1024, 256, 128, 64, 32]
@@ -296,7 +299,7 @@ with demo:
296
  with gr.Box() as config:
297
  gr.HTML("""<p style='padding-bottom: 0.5rem; color: #6b7280; '>Quantization config</p>""")
298
  with gr.Row():
299
- filter_columns_computeDtype = gr.Dropdown(choices=[i.value.name for i in ComputeDtype], label="Compute Dtype", multiselect=False, value="float16", interactive=True,)
300
  filter_columns_weightDtype = gr.Dropdown(choices=[i.value.name for i in WeightDtype], label="Weight Dtype", multiselect=False, value="All", interactive=True,)
301
  filter_columns_doubleQuant = gr.Dropdown(choices=["True", "False"], label="Double Quant", multiselect=False, value=False, interactive=True)
302
  filter_columns_groupDtype = gr.Dropdown(choices=[i.value.name for i in GroupDtype], label="Group Size", multiselect=False, value="All", interactive=True,)
 
116
  group_dtype: str
117
  ):
118
 
 
 
119
  if weight_dtype == 'All':
120
  weight_dtype = ['int2', 'int3', 'int4', 'nf4', 'fp4']
121
  else:
122
  weight_dtype = [weight_dtype]
123
+
124
+ if compute_dtype == 'All':
125
+ compute_dtype = ['bfloat16', 'float16', 'int8', 'float32']
126
+ else:
127
+ compute_dtype = [compute_dtype]
128
 
129
  if group_dtype == 'All':
130
  group_dtype = [-1, 1024, 256, 128, 64, 32]
 
299
  with gr.Box() as config:
300
  gr.HTML("""<p style='padding-bottom: 0.5rem; color: #6b7280; '>Quantization config</p>""")
301
  with gr.Row():
302
+ filter_columns_computeDtype = gr.Dropdown(choices=[i.value.name for i in ComputeDtype], label="Compute Dtype", multiselect=False, value="All", interactive=True,)
303
  filter_columns_weightDtype = gr.Dropdown(choices=[i.value.name for i in WeightDtype], label="Weight Dtype", multiselect=False, value="All", interactive=True,)
304
  filter_columns_doubleQuant = gr.Dropdown(choices=["True", "False"], label="Double Quant", multiselect=False, value=False, interactive=True)
305
  filter_columns_groupDtype = gr.Dropdown(choices=[i.value.name for i in GroupDtype], label="Group Size", multiselect=False, value="All", interactive=True,)