inoki-giskard commited on
Commit
d65e913
1 Parent(s): 7cdd792

Fix mapping preview

Browse files
Files changed (1) hide show
  1. text_classification.py +11 -6
text_classification.py CHANGED
@@ -91,22 +91,27 @@ def text_classification_fix_column_mapping(column_mapping, ppl, d_id, config, sp
91
  return column_mapping, prediction_result, None
92
 
93
  if isinstance(column_mapping["label"], dict):
94
- for model_label in id2label_mapping.keys():
95
- id2label_mapping[model_label] = column_mapping["label"][str(label2id[model_label])]
 
96
  elif None in id2label_mapping.values():
97
  column_mapping["label"] = {
98
  i: None for i in id2label.keys()
99
  }
100
  return column_mapping, prediction_result, None
101
 
 
 
 
102
  id2label_df = pd.DataFrame({
103
- "ID": [i for i in id2label.keys()],
104
- "Model labels": [id2label[label] for label in id2label.keys()],
105
- "Dataset labels": [id2label_mapping[id2label[label]] for label in id2label.keys()],
106
  })
107
  if "label" not in column_mapping.keys():
 
108
  column_mapping["label"] = {
109
- i: id2label_mapping[id2label[i]] for i in id2label.keys()
110
  }
111
 
112
  return column_mapping, prediction_result, id2label_df
 
91
  return column_mapping, prediction_result, None
92
 
93
  if isinstance(column_mapping["label"], dict):
94
+ # Use the column mapping passed by user
95
+ for i, model_label in column_mapping["label"].items():
96
+ id2label_mapping[model_label] = dataset_labels[int(i)]
97
  elif None in id2label_mapping.values():
98
  column_mapping["label"] = {
99
  i: None for i in id2label.keys()
100
  }
101
  return column_mapping, prediction_result, None
102
 
103
+ id2label_mapping = {
104
+ v: k for k, v in id2label_mapping.items()
105
+ }
106
  id2label_df = pd.DataFrame({
107
+ "ID": list(range(len(dataset_labels))),
108
+ "Labels": dataset_labels,
109
+ "Labels in original model": [f"{id2label_mapping[label]}({label2id[id2label_mapping[label]]})" for label in dataset_labels],
110
  })
111
  if "label" not in column_mapping.keys():
112
+ # Column mapping should contain original model labels
113
  column_mapping["label"] = {
114
+ str(i): id2label_mapping[label] for i, label in zip(id2label.keys(), dataset_labels)
115
  }
116
 
117
  return column_mapping, prediction_result, id2label_df