ZeroCommand commited on
Commit
ba41a5c
1 Parent(s): f0a313e

hide dropdown menus when the labels match

Browse files
Files changed (1) hide show
  1. app_text_classification.py +33 -15
app_text_classification.py CHANGED
@@ -4,6 +4,7 @@ import os
4
  import time
5
  import subprocess
6
  import logging
 
7
 
8
  import json
9
 
@@ -169,14 +170,12 @@ def get_demo():
169
  all_mappings["features"][feat] = ds_features[i]
170
  write_column_mapping(all_mappings)
171
 
172
- def list_labels_and_features_from_dataset(dataset_id, dataset_config, dataset_split, model_id2label, model_features):
173
- ds_labels, ds_features = get_labels_and_features_from_dataset(dataset_id, dataset_config, dataset_split)
174
- if ds_labels is None or ds_features is None:
175
- return [gr.Dropdown(visible=False) for _ in range(MAX_LABELS + MAX_FEATURES)]
176
  model_labels = list(model_id2label.values())
177
  lables = [gr.Dropdown(label=f"{label}", choices=model_labels, value=model_id2label[i], interactive=True, visible=True) for i, label in enumerate(ds_labels[:MAX_LABELS])]
178
  lables += [gr.Dropdown(visible=False) for _ in range(MAX_LABELS - len(lables))]
179
- features = [gr.Dropdown(label=f"{feature}", choices=ds_features, value=ds_features[0], interactive=True, visible=True) for feature in model_features]
 
180
  features += [gr.Dropdown(visible=False) for _ in range(MAX_FEATURES - len(features))]
181
  return lables + features
182
 
@@ -196,24 +195,43 @@ def get_demo():
196
  gr.update(visible=False),
197
  *[gr.update(visible=False) for _ in range(MAX_LABELS + MAX_FEATURES)]
198
  )
199
- model_id2label = ppl.model.config.id2label
200
- model_features = ['text']
201
- column_mappings = list_labels_and_features_from_dataset(
202
- dataset_id,
203
- dataset_config,
204
- dataset_split,
205
- model_id2label,
206
- model_features
207
- )
208
-
209
  if ppl is None:
210
  gr.Warning("Model not found")
211
  return (
212
  gr.update(visible=False),
213
  gr.update(visible=False),
214
  gr.update(visible=False, open=False),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
  *column_mappings
216
  )
 
217
  prediction_input, prediction_output = get_example_prediction(ppl, dataset_id, dataset_config, dataset_split)
218
  return (
219
  gr.update(value=prediction_input, visible=True),
 
4
  import time
5
  import subprocess
6
  import logging
7
+ import threading
8
 
9
  import json
10
 
 
170
  all_mappings["features"][feat] = ds_features[i]
171
  write_column_mapping(all_mappings)
172
 
173
+ def list_labels_and_features_from_dataset(ds_labels, ds_features, model_id2label):
 
 
 
174
  model_labels = list(model_id2label.values())
175
  lables = [gr.Dropdown(label=f"{label}", choices=model_labels, value=model_id2label[i], interactive=True, visible=True) for i, label in enumerate(ds_labels[:MAX_LABELS])]
176
  lables += [gr.Dropdown(visible=False) for _ in range(MAX_LABELS - len(lables))]
177
+ # TODO: Substitute 'text' with more features for zero-shot
178
+ features = [gr.Dropdown(label=f"{feature}", choices=ds_features, value=ds_features[0], interactive=True, visible=True) for feature in ['text']]
179
  features += [gr.Dropdown(visible=False) for _ in range(MAX_FEATURES - len(features))]
180
  return lables + features
181
 
 
195
  gr.update(visible=False),
196
  *[gr.update(visible=False) for _ in range(MAX_LABELS + MAX_FEATURES)]
197
  )
198
+
199
+ dropdown_placement = [gr.Dropdown(visible=False) for _ in range(MAX_LABELS + MAX_FEATURES)]
200
+
 
 
 
 
 
 
 
201
  if ppl is None:
202
  gr.Warning("Model not found")
203
  return (
204
  gr.update(visible=False),
205
  gr.update(visible=False),
206
  gr.update(visible=False, open=False),
207
+ *dropdown_placement
208
+ )
209
+ model_id2label = ppl.model.config.id2label
210
+ ds_labels, ds_features = get_labels_and_features_from_dataset(dataset_id, dataset_config, dataset_split)
211
+
212
+ if ds_labels is None or ds_features is None:
213
+ gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
214
+ return (
215
+ gr.update(visible=False),
216
+ gr.update(visible=False),
217
+ gr.update(visible=False, open=False),
218
+ *dropdown_placement
219
+ )
220
+
221
+ column_mappings = list_labels_and_features_from_dataset(
222
+ ds_labels,
223
+ ds_features,
224
+ model_id2label,
225
+ )
226
+ if model_id2label.items() != ds_labels.items():
227
+ gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
228
+ return (
229
+ gr.update(visible=False),
230
+ gr.update(visible=False),
231
+ gr.update(visible=True, open=True),
232
  *column_mappings
233
  )
234
+
235
  prediction_input, prediction_output = get_example_prediction(ppl, dataset_id, dataset_config, dataset_split)
236
  return (
237
  gr.update(value=prediction_input, visible=True),