ZeroCommand commited on
Commit
804fff2
1 Parent(s): 85a8a8b

restructure and empty col mapping when comfig changes

Browse files
app_text_classification.py CHANGED
@@ -7,12 +7,14 @@ from text_classification_ui_helpers import (
7
  get_related_datasets_from_leaderboard,
8
  align_columns_and_show_prediction,
9
  check_dataset,
10
- precheck_model_ds_enable_example_btn,
11
  show_hf_token_info,
 
12
  try_submit,
 
13
  write_column_mapping_to_config,
 
14
  )
15
- from text_classification import get_example_prediction, HuggingFaceInferenceAPIResponse
16
  import logging
17
  from wordings import (
18
  CONFIRM_MAPPING_DETAILS_MD,
@@ -157,6 +159,12 @@ def get_demo():
157
  outputs=[dataset_config_input, dataset_split_input, loading_status]
158
  )
159
 
 
 
 
 
 
 
160
  gr.on(
161
  triggers=[label.change for label in column_mappings],
162
  fn=write_column_mapping_to_config,
@@ -234,24 +242,6 @@ def get_demo():
234
  outputs=[run_btn, logs, uid_label],
235
  )
236
 
237
- def enable_run_btn(run_inference, inference_token, model_id, dataset_id, dataset_config, dataset_split):
238
- if not run_inference or inference_token == "":
239
- logger.warn("Inference API is not enabled")
240
- return gr.update(interactive=False)
241
- if model_id == "" or dataset_id == "" or dataset_config == "" or dataset_split == "":
242
- logger.warn("Model id or dataset id is not selected")
243
- return gr.update(interactive=False)
244
- if not column_mapping_accordion.visible:
245
- logger.warn("Column mapping is not confirmed")
246
- return gr.update(interactive=False)
247
- _, prediction_response = get_example_prediction(
248
- model_id, dataset_id, dataset_config, dataset_split, inference_token
249
- )
250
- if not isinstance(prediction_response, HuggingFaceInferenceAPIResponse):
251
- logger.warn("Failed to get example prediction from inference API with this token")
252
- return gr.update(interactive=False)
253
- return gr.update(interactive=True)
254
-
255
  gr.on(
256
  triggers=[
257
  run_inference.input,
@@ -260,6 +250,7 @@ def get_demo():
260
  ],
261
  fn=enable_run_btn,
262
  inputs=[
 
263
  run_inference,
264
  inference_token,
265
  model_id_input,
@@ -274,6 +265,7 @@ def get_demo():
274
  triggers=[label.input for label in column_mappings],
275
  fn=enable_run_btn,
276
  inputs=[
 
277
  run_inference,
278
  inference_token,
279
  model_id_input,
 
7
  get_related_datasets_from_leaderboard,
8
  align_columns_and_show_prediction,
9
  check_dataset,
 
10
  show_hf_token_info,
11
+ precheck_model_ds_enable_example_btn,
12
  try_submit,
13
+ empty_column_mapping,
14
  write_column_mapping_to_config,
15
+ enable_run_btn,
16
  )
17
+
18
  import logging
19
  from wordings import (
20
  CONFIRM_MAPPING_DETAILS_MD,
 
159
  outputs=[dataset_config_input, dataset_split_input, loading_status]
160
  )
161
 
162
+ gr.on(
163
+ triggers=[model_id_input.change, dataset_id_input.change, dataset_config_input.change],
164
+ fn=empty_column_mapping,
165
+ inputs=[uid_label]
166
+ )
167
+
168
  gr.on(
169
  triggers=[label.change for label in column_mappings],
170
  fn=write_column_mapping_to_config,
 
242
  outputs=[run_btn, logs, uid_label],
243
  )
244
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
  gr.on(
246
  triggers=[
247
  run_inference.input,
 
250
  ],
251
  fn=enable_run_btn,
252
  inputs=[
253
+ uid_label,
254
  run_inference,
255
  inference_token,
256
  model_id_input,
 
265
  triggers=[label.input for label in column_mappings],
266
  fn=enable_run_btn,
267
  inputs=[
268
+ uid_label,
269
  run_inference,
270
  inference_token,
271
  model_id_input,
text_classification_ui_helpers.py CHANGED
@@ -80,7 +80,8 @@ def check_dataset(dataset_id):
80
  ""
81
  )
82
 
83
-
 
84
 
85
  def write_column_mapping_to_config(uid, *labels):
86
  # TODO: Substitute 'text' with more features for zero-shot
@@ -99,7 +100,6 @@ def write_column_mapping_to_config(uid, *labels):
99
 
100
  write_column_mapping(all_mappings, uid)
101
 
102
-
103
  def export_mappings(all_mappings, key, subkeys, values):
104
  if key not in all_mappings.keys():
105
  all_mappings[key] = dict()
@@ -308,12 +308,34 @@ def align_columns_and_show_prediction(
308
  def check_column_mapping_keys_validity(all_mappings):
309
  if all_mappings is None:
310
  gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
311
- return (gr.update(interactive=True), gr.update(visible=False))
312
 
313
  if "labels" not in all_mappings.keys():
314
  gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
315
- return (gr.update(interactive=True), gr.update(visible=False))
 
 
316
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
317
 
318
  def construct_label_and_feature_mapping(all_mappings, ds_labels, ds_features):
319
  label_mapping = {}
@@ -340,7 +362,8 @@ def show_hf_token_info(token):
340
 
341
  def try_submit(m_id, d_id, config, split, inference, inference_token, uid):
342
  all_mappings = read_column_mapping(uid)
343
- check_column_mapping_keys_validity(all_mappings)
 
344
 
345
  # get ds labels and features again for alignment
346
  ds = datasets.load_dataset(d_id, config, split=split, trust_remote_code=True)
 
80
  ""
81
  )
82
 
83
+ def empty_column_mapping(uid):
84
+ write_column_mapping(None, uid)
85
 
86
  def write_column_mapping_to_config(uid, *labels):
87
  # TODO: Substitute 'text' with more features for zero-shot
 
100
 
101
  write_column_mapping(all_mappings, uid)
102
 
 
103
  def export_mappings(all_mappings, key, subkeys, values):
104
  if key not in all_mappings.keys():
105
  all_mappings[key] = dict()
 
308
  def check_column_mapping_keys_validity(all_mappings):
309
  if all_mappings is None:
310
  gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
311
+ return False
312
 
313
  if "labels" not in all_mappings.keys():
314
  gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
315
+ return False
316
+
317
+ return True
318
 
319
+ def enable_run_btn(uid, run_inference, inference_token, model_id, dataset_id, dataset_config, dataset_split):
320
+ if not run_inference or inference_token == "":
321
+ logger.warn("Inference API is not enabled")
322
+ return gr.update(interactive=False)
323
+ if model_id == "" or dataset_id == "" or dataset_config == "" or dataset_split == "":
324
+ logger.warn("Model id or dataset id is not selected")
325
+ return gr.update(interactive=False)
326
+
327
+ all_mappings = read_column_mapping(uid)
328
+ if not check_column_mapping_keys_validity(all_mappings):
329
+ logger.warn("Column mapping is not valid")
330
+ return gr.update(interactive=False)
331
+
332
+ _, prediction_response = get_example_prediction(
333
+ model_id, dataset_id, dataset_config, dataset_split, inference_token
334
+ )
335
+ if not isinstance(prediction_response, HuggingFaceInferenceAPIResponse):
336
+ logger.warn("Failed to get example prediction from inference API with this token")
337
+ return gr.update(interactive=False)
338
+ return gr.update(interactive=True)
339
 
340
  def construct_label_and_feature_mapping(all_mappings, ds_labels, ds_features):
341
  label_mapping = {}
 
362
 
363
  def try_submit(m_id, d_id, config, split, inference, inference_token, uid):
364
  all_mappings = read_column_mapping(uid)
365
+ if not check_column_mapping_keys_validity(all_mappings):
366
+ return (gr.update(interactive=True), gr.update(visible=False))
367
 
368
  # get ds labels and features again for alignment
369
  ds = datasets.load_dataset(d_id, config, split=split, trust_remote_code=True)