ZeroCommand commited on
Commit
5058ff3
1 Parent(s): ba41a5c

hide dropdown menus when the labels match when labels not matching

Browse files
Files changed (1) hide show
  1. app_text_classification.py +9 -5
app_text_classification.py CHANGED
@@ -4,7 +4,7 @@ import os
4
  import time
5
  import subprocess
6
  import logging
7
- import threading
8
 
9
  import json
10
 
@@ -198,7 +198,7 @@ def get_demo():
198
 
199
  dropdown_placement = [gr.Dropdown(visible=False) for _ in range(MAX_LABELS + MAX_FEATURES)]
200
 
201
- if ppl is None:
202
  gr.Warning("Model not found")
203
  return (
204
  gr.update(visible=False),
@@ -209,7 +209,8 @@ def get_demo():
209
  model_id2label = ppl.model.config.id2label
210
  ds_labels, ds_features = get_labels_and_features_from_dataset(dataset_id, dataset_config, dataset_split)
211
 
212
- if ds_labels is None or ds_features is None:
 
213
  gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
214
  return (
215
  gr.update(visible=False),
@@ -223,7 +224,10 @@ def get_demo():
223
  ds_features,
224
  model_id2label,
225
  )
226
- if model_id2label.items() != ds_labels.items():
 
 
 
227
  gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
228
  return (
229
  gr.update(visible=False),
@@ -236,7 +240,7 @@ def get_demo():
236
  return (
237
  gr.update(value=prediction_input, visible=True),
238
  gr.update(value=prediction_output, visible=True),
239
- gr.update(open=False),
240
  *column_mappings
241
  )
242
 
 
4
  import time
5
  import subprocess
6
  import logging
7
+ import collections
8
 
9
  import json
10
 
 
198
 
199
  dropdown_placement = [gr.Dropdown(visible=False) for _ in range(MAX_LABELS + MAX_FEATURES)]
200
 
201
+ if ppl is None: # pipeline not found
202
  gr.Warning("Model not found")
203
  return (
204
  gr.update(visible=False),
 
209
  model_id2label = ppl.model.config.id2label
210
  ds_labels, ds_features = get_labels_and_features_from_dataset(dataset_id, dataset_config, dataset_split)
211
 
212
+ # when dataset does not have labels or features
213
+ if not isinstance(ds_labels, list) or not isinstance(ds_features, list):
214
  gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
215
  return (
216
  gr.update(visible=False),
 
224
  ds_features,
225
  model_id2label,
226
  )
227
+
228
+ # when labels or features are not aligned
229
+ # show manually column mapping
230
+ if collections.Counter(model_id2label.items()) != collections.Counter(ds_labels) or ds_features[0] != 'text':
231
  gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
232
  return (
233
  gr.update(visible=False),
 
240
  return (
241
  gr.update(value=prediction_input, visible=True),
242
  gr.update(value=prediction_output, visible=True),
243
+ gr.update(visible=True, open=False),
244
  *column_mappings
245
  )
246