ZeroCommand commited on
Commit
9f91529
1 Parent(s): 2569404

return label keys and add in mapping before submission

Browse files
text_classification.py CHANGED
@@ -26,23 +26,22 @@ def get_labels_and_features_from_dataset(ds):
26
  features = [f for f in dataset_features.keys() if not f.startswith("label")]
27
  if len(label_keys) == 0: # no labels found
28
  label_keys = [i for i in dataset_features.keys() if i.startswith('label')]
29
- features += label_keys
30
 
31
  if len(label_keys) == 0: # no labels found
32
  # return everything for post processing
33
- return list(dataset_features.keys()), list(dataset_features.keys())
34
  if not isinstance(dataset_features[label_keys[0]], datasets.ClassLabel):
35
  if hasattr(dataset_features[label_keys[0]], 'feature'):
36
  label_feat = dataset_features[label_keys[0]].feature
37
  labels = label_feat.names
38
  else:
39
  labels = dataset_features[label_keys[0]].names
40
- return labels, features
41
  except Exception as e:
42
  logging.warning(
43
  f"Get Labels/Features Failed for dataset: {e}"
44
  )
45
- return None, None
46
 
47
  def check_model_task(model_id):
48
  # check if model is valid on huggingface
 
26
  features = [f for f in dataset_features.keys() if not f.startswith("label")]
27
  if len(label_keys) == 0: # no labels found
28
  label_keys = [i for i in dataset_features.keys() if i.startswith('label')]
 
29
 
30
  if len(label_keys) == 0: # no labels found
31
  # return everything for post processing
32
+ return list(dataset_features.keys()), list(dataset_features.keys()), None
33
  if not isinstance(dataset_features[label_keys[0]], datasets.ClassLabel):
34
  if hasattr(dataset_features[label_keys[0]], 'feature'):
35
  label_feat = dataset_features[label_keys[0]].feature
36
  labels = label_feat.names
37
  else:
38
  labels = dataset_features[label_keys[0]].names
39
+ return labels, features, label_keys
40
  except Exception as e:
41
  logging.warning(
42
  f"Get Labels/Features Failed for dataset: {e}"
43
  )
44
+ return None, None, None
45
 
46
  def check_model_task(model_id):
47
  # check if model is valid on huggingface
text_classification_ui_helpers.py CHANGED
@@ -198,7 +198,7 @@ def precheck_model_ds_enable_example_btn(
198
  try:
199
  ds = datasets.load_dataset(dataset_id, dataset_config, trust_remote_code=True)
200
  df: pd.DataFrame = ds[dataset_split].to_pandas().head(5)
201
- ds_labels, ds_features = get_labels_and_features_from_dataset(ds[dataset_split])
202
 
203
  if model_task is None or model_task != "text-classification":
204
  gr.Warning(NOT_TEXT_CLASSIFICATION_MODEL_RAW)
@@ -300,7 +300,7 @@ def align_columns_and_show_prediction(
300
  model_labels = list(prediction_response.keys())
301
 
302
  ds = datasets.load_dataset(dataset_id, dataset_config, split=dataset_split, trust_remote_code=True)
303
- ds_labels, ds_features = get_labels_and_features_from_dataset(ds)
304
 
305
  # when dataset does not have labels or features
306
  if not isinstance(ds_labels, list) or not isinstance(ds_features, list):
@@ -390,7 +390,7 @@ def enable_run_btn(uid, run_inference, inference_token, model_id, dataset_id, da
390
  return gr.update(interactive=False)
391
  return gr.update(interactive=True)
392
 
393
- def construct_label_and_feature_mapping(all_mappings, ds_labels, ds_features):
394
  label_mapping = {}
395
  if len(all_mappings["labels"].keys()) != len(ds_labels):
396
  logger.warn(f"""Label mapping corrupted: {CONFIRM_MAPPING_DETAILS_FAIL_RAW}.
@@ -408,9 +408,9 @@ def construct_label_and_feature_mapping(all_mappings, ds_labels, ds_features):
408
  logger.warning("features not in all_mappings")
409
  gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
410
 
411
- special_classlabel = [feature for feature in ds_features if feature.startswith("label")]
412
  feature_mapping = all_mappings["features"]
413
- feature_mapping.update({"label": special_classlabel[0] if special_classlabel else "label"})
 
414
  return label_mapping, feature_mapping
415
 
416
  def show_hf_token_info(token):
@@ -426,8 +426,8 @@ def try_submit(m_id, d_id, config, split, inference, inference_token, uid):
426
 
427
  # get ds labels and features again for alignment
428
  ds = datasets.load_dataset(d_id, config, split=split, trust_remote_code=True)
429
- ds_labels, ds_features = get_labels_and_features_from_dataset(ds)
430
- label_mapping, feature_mapping = construct_label_and_feature_mapping(all_mappings, ds_labels, ds_features)
431
 
432
  eval_str = f"[{m_id}]<{d_id}({config}, {split} set)>"
433
  save_job_to_pipe(
 
198
  try:
199
  ds = datasets.load_dataset(dataset_id, dataset_config, trust_remote_code=True)
200
  df: pd.DataFrame = ds[dataset_split].to_pandas().head(5)
201
+ ds_labels, ds_features, _ = get_labels_and_features_from_dataset(ds[dataset_split])
202
 
203
  if model_task is None or model_task != "text-classification":
204
  gr.Warning(NOT_TEXT_CLASSIFICATION_MODEL_RAW)
 
300
  model_labels = list(prediction_response.keys())
301
 
302
  ds = datasets.load_dataset(dataset_id, dataset_config, split=dataset_split, trust_remote_code=True)
303
+ ds_labels, ds_features, _ = get_labels_and_features_from_dataset(ds)
304
 
305
  # when dataset does not have labels or features
306
  if not isinstance(ds_labels, list) or not isinstance(ds_features, list):
 
390
  return gr.update(interactive=False)
391
  return gr.update(interactive=True)
392
 
393
+ def construct_label_and_feature_mapping(all_mappings, ds_labels, ds_features, label_keys=None):
394
  label_mapping = {}
395
  if len(all_mappings["labels"].keys()) != len(ds_labels):
396
  logger.warn(f"""Label mapping corrupted: {CONFIRM_MAPPING_DETAILS_FAIL_RAW}.
 
408
  logger.warning("features not in all_mappings")
409
  gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
410
 
 
411
  feature_mapping = all_mappings["features"]
412
+ if len(label_keys) > 0:
413
+ feature_mapping.update({"label": label_keys[0]})
414
  return label_mapping, feature_mapping
415
 
416
  def show_hf_token_info(token):
 
426
 
427
  # get ds labels and features again for alignment
428
  ds = datasets.load_dataset(d_id, config, split=split, trust_remote_code=True)
429
+ ds_labels, ds_features, label_keys = get_labels_and_features_from_dataset(ds)
430
+ label_mapping, feature_mapping = construct_label_and_feature_mapping(all_mappings, ds_labels, ds_features, label_keys)
431
 
432
  eval_str = f"[{m_id}]<{d_id}({config}, {split} set)>"
433
  save_job_to_pipe(