GSK-2396-allow-edit-feature-mappings

#12
by ZeroCommand - opened
Files changed (4)
  1. app.py +69 -41
  2. config.yaml +9 -0
  3. text_classification.py +56 -34
  4. utils.py +54 -0
app.py CHANGED
@@ -11,13 +11,12 @@ import json
 from transformers.pipelines import TextClassificationPipeline
 
 from text_classification import check_column_mapping_keys_validity, text_classification_fix_column_mapping
-
+from utils import read_scanners, write_scanners, read_inference_type, write_inference_type, convert_column_mapping_to_json
 
 HF_REPO_ID = 'HF_REPO_ID'
 HF_SPACE_ID = 'SPACE_ID'
 HF_WRITE_TOKEN = 'HF_WRITE_TOKEN'
 
-
 theme = gr.themes.Soft(
     primary_hue="green",
 )
@@ -70,6 +69,7 @@ def try_validate(m_id, ppl, dataset_id, dataset_config, dataset_split, column_ma
             gr.update(visible=False),  # Model prediction input
             gr.update(visible=False),  # Model prediction preview
             gr.update(visible=False),  # Label mapping preview
+            gr.update(visible=False),  # Feature mapping preview
         )
     if isinstance(ppl, Exception):
         gr.Warning(f'Failed to load model: {ppl}')
@@ -80,6 +80,7 @@ def try_validate(m_id, ppl, dataset_id, dataset_config, dataset_split, column_ma
             gr.update(visible=False),  # Model prediction input
             gr.update(visible=False),  # Model prediction preview
             gr.update(visible=False),  # Label mapping preview
+            gr.update(visible=False),  # Feature mapping preview
         )
 
     # Validate dataset
@@ -105,7 +106,7 @@ def try_validate(m_id, ppl, dataset_id, dataset_config, dataset_split, column_ma
             gr.update(visible=False),  # Model prediction input
             gr.update(visible=False),  # Model prediction preview
             gr.update(visible=False),  # Label mapping preview
-            # gr.update(visible=True),  # Column mapping
+            gr.update(visible=False),  # Feature mapping preview
         )
 
     # TODO: Validate column mapping by running once
@@ -118,21 +119,21 @@ def try_validate(m_id, ppl, dataset_id, dataset_config, dataset_split, column_ma
     except Exception:
         column_mapping = {}
 
-    column_mapping, prediction_input, prediction_result, id2label_df = \
+    column_mapping, prediction_input, prediction_result, id2label_df, feature_df = \
         text_classification_fix_column_mapping(column_mapping, ppl, d_id, config, split)
 
     column_mapping = json.dumps(column_mapping, indent=2)
 
-    if prediction_result is None:
+    if prediction_result is None and id2label_df is not None:
         gr.Warning('The model failed to predict with the first row in the dataset. Please provide column mappings in "Advanced" settings.')
         return (
             gr.update(interactive=False),  # Submit button
-            gr.update(visible=True),  # Loading row
-            gr.update(visible=False),  # Preview row
-            gr.update(visible=False),  # Model prediction input
+            gr.update(visible=False),  # Loading row
+            gr.update(visible=True),  # Preview row
+            gr.update(value=f'**Sample Input**: {prediction_input}', visible=True),  # Model prediction input
             gr.update(visible=False),  # Model prediction preview
-            gr.update(visible=False),  # Label mapping preview
-            # gr.update(value=column_mapping, visible=True, interactive=True),  # Column mapping
+            gr.update(value=id2label_df, visible=True, interactive=True),  # Label mapping preview
+            gr.update(value=feature_df, visible=True, interactive=True),  # Feature mapping preview
         )
     elif id2label_df is None:
         gr.Warning('The prediction result does not conform to the labels in the dataset. Please provide label mappings in "Advanced" settings.')
@@ -142,8 +143,8 @@ def try_validate(m_id, ppl, dataset_id, dataset_config, dataset_split, column_ma
             gr.update(visible=True),  # Preview row
             gr.update(value=f'**Sample Input**: {prediction_input}', visible=True),  # Model prediction input
             gr.update(value=prediction_result, visible=True),  # Model prediction preview
-            gr.update(visible=False),  # Label mapping preview
-            # gr.update(value=column_mapping, visible=True, interactive=True),  # Column mapping
+            gr.update(visible=True, interactive=True),  # Label mapping preview
+            gr.update(visible=True, interactive=True),  # Feature mapping preview
         )
 
     gr.Info("Model and dataset validations passed. You can submit the evaluation task.")
@@ -155,13 +156,18 @@ def try_validate(m_id, ppl, dataset_id, dataset_config, dataset_split, column_ma
         gr.update(value=f'**Sample Input**: {prediction_input}', visible=True),  # Model prediction input
         gr.update(value=prediction_result, visible=True),  # Model prediction preview
         gr.update(value=id2label_df, visible=True, interactive=True),  # Label mapping preview
+        gr.update(value=feature_df, visible=True, interactive=True),  # Feature mapping preview
     )
 
 
-def try_submit(m_id, d_id, config, split, column_mappings, local):
+def try_submit(m_id, d_id, config, split, id2label_mapping_dataframe, feature_mapping_dataframe, local):
     label_mapping = {}
-    for i, label in column_mappings["Model Prediction Labels"].items():
+    for i, label in id2label_mapping_dataframe["Model Prediction Labels"].items():
         label_mapping.update({str(i): label})
+
+    feature_mapping = {}
+    for i, feature in feature_mapping_dataframe["Dataset Features"].items():
+        feature_mapping.update({feature_mapping_dataframe["Model Input Features"][i]: feature})
 
     # TODO: Set column mapping for some dataset such as `amazon_polarity`
 
@@ -178,8 +184,9 @@ def try_submit(m_id, d_id, config, split, column_mappings, local):
         "--discussion_repo", os.environ.get(HF_REPO_ID) or os.environ.get(HF_SPACE_ID),
         "--output_format", "markdown",
         "--output_portal", "huggingface",
-        # TODO: "--feature_mapping", json.dumps(column_mapping),
+        "--feature_mapping", json.dumps(feature_mapping),
         "--label_mapping", json.dumps(label_mapping),
+        "--scan_config", "./config.yaml",
     ]
 
     eval_str = f"[{m_id}]<{d_id}({config}, {split} set)>"
@@ -221,12 +228,15 @@ with gr.Blocks(theme=theme) as iface:
            gr.Warning(f"Failed to load dataset {dataset_id} with config {dataset_config}: {e}")
            pass
 
-    def gate_validate_btn(model_id, dataset_id, dataset_config, dataset_split, id2label_mapping_dataframe=None):
+    def gate_validate_btn(model_id, dataset_id, dataset_config, dataset_split, id2label_mapping_dataframe=None, feature_mapping_dataframe=None):
         column_mapping = '{}'
-        m_id, ppl = check_model(model_id=model_id)
+        _, ppl = check_model(model_id=model_id)
 
         if id2label_mapping_dataframe is not None:
-            column_mapping = id2label_mapping_dataframe.to_json(orient="split")
+            labels = convert_column_mapping_to_json(id2label_mapping_dataframe.value, label="data")
+            features = convert_column_mapping_to_json(feature_mapping_dataframe.value, label="text")
+            column_mapping = json.dumps({**labels, **features}, indent=2)
+
         if check_column_mapping_keys_validity(column_mapping, ppl) is False:
             gr.Warning('Label mapping table has invalid contents. Please check again.')
             return (gr.update(interactive=False),
@@ -234,18 +244,18 @@ with gr.Blocks(theme=theme) as iface:
                     gr.update(),
                     gr.update(),
                     gr.update(),
+                    gr.update(),
                     gr.update())
         else:
             if model_id and dataset_id and dataset_config and dataset_split:
-                return try_validate(m_id, ppl, dataset_id, dataset_config, dataset_split, column_mapping)
+                return try_validate(model_id, ppl, dataset_id, dataset_config, dataset_split, column_mapping)
             else:
-                del ppl
-
                 return (gr.update(interactive=False),
                         gr.update(visible=True),
                         gr.update(visible=False),
                         gr.update(visible=False),
                         gr.update(visible=False),
+                        gr.update(visible=False),
                         gr.update(visible=False))
     with gr.Row():
         gr.Markdown('''
@@ -256,6 +266,13 @@ with gr.Blocks(theme=theme) as iface:
     ''')
     with gr.Row():
         run_local = gr.Checkbox(value=True, label="Run in this Space")
+        use_inference = read_inference_type('./config.yaml') == 'hf_inference_api'
+        run_inference = gr.Checkbox(value=use_inference, label="Run with Inference API")
+
+    with gr.Row() as advanced_row:
+        selected = read_scanners('./config.yaml')
+        scan_config = selected + ['data_leakage']
+        scanners = gr.CheckboxGroup(choices=scan_config, value=selected, label='Scan Settings', visible=True)
 
     with gr.Row():
         model_id_input = gr.Textbox(
@@ -271,30 +288,32 @@ with gr.Blocks(theme=theme) as iface:
         dataset_config_input = gr.Dropdown(['default'], value='default', label='Dataset Config', visible=False)
         dataset_split_input = gr.Dropdown(['default'], value='default', label='Dataset Split', visible=False)
 
-        dataset_id_input.change(check_dataset_and_get_config, dataset_id_input, dataset_config_input)
-        dataset_config_input.change(
+        dataset_id_input.blur(check_dataset_and_get_config, dataset_id_input, dataset_config_input)
+        dataset_id_input.submit(check_dataset_and_get_config, dataset_id_input, dataset_config_input)
+
+        dataset_config_input.blur(
            check_dataset_and_get_split,
            inputs=[dataset_config_input, dataset_id_input],
            outputs=[dataset_split_input])
 
     with gr.Row(visible=True) as loading_row:
         gr.Markdown('''
-        <h1 style="text-align: center;">
-        Please validate your model and dataset first...
-        </h1>
+        <p style="text-align: center;">
+        🚀🐢 Please validate your model and dataset first...
+        </p>
        ''')
 
     with gr.Row(visible=False) as preview_row:
         gr.Markdown('''
         <h1 style="text-align: center;">
-        Confirm Label Details
+        Confirm Pre-processing Details
         </h1>
-        Base on your model and dataset, we inferred this label mapping. **If the mapping is incorrect, please modify it in the table below.**
+        Based on your model and dataset, we inferred this label mapping and feature mapping. <b>If the mapping is incorrect, please modify it in the table below.</b>
        ''')
 
     with gr.Row():
         id2label_mapping_dataframe = gr.DataFrame(label="Preview of label mapping", interactive=True, visible=False)
-
+        feature_mapping_dataframe = gr.DataFrame(label="Preview of feature mapping", interactive=True, visible=False)
     with gr.Row():
         example_input = gr.Markdown('Sample Input: ', visible=False)
 
@@ -308,22 +327,30 @@ with gr.Blocks(theme=theme) as iface:
         size="lg",
     )
 
-    model_id_input.change(gate_validate_btn,
+    model_id_input.blur(gate_validate_btn,
                          inputs=[model_id_input, dataset_id_input, dataset_config_input, dataset_split_input],
-                         outputs=[run_btn, loading_row, preview_row, example_input, example_labels, id2label_mapping_dataframe])
-    dataset_id_input.change(gate_validate_btn,
+                         outputs=[run_btn, loading_row, preview_row, example_input, example_labels, id2label_mapping_dataframe, feature_mapping_dataframe])
+    dataset_id_input.blur(gate_validate_btn,
                          inputs=[model_id_input, dataset_id_input, dataset_config_input, dataset_split_input],
-                         outputs=[run_btn, loading_row, preview_row, example_input, example_labels, id2label_mapping_dataframe])
-    dataset_config_input.change(gate_validate_btn,
+                         outputs=[run_btn, loading_row, preview_row, example_input, example_labels, id2label_mapping_dataframe, feature_mapping_dataframe])
+    dataset_config_input.input(gate_validate_btn,
                          inputs=[model_id_input, dataset_id_input, dataset_config_input, dataset_split_input],
-                         outputs=[run_btn, loading_row, preview_row, example_input, example_labels, id2label_mapping_dataframe])
-    dataset_split_input.change(gate_validate_btn,
+                         outputs=[run_btn, loading_row, preview_row, example_input, example_labels, id2label_mapping_dataframe, feature_mapping_dataframe])
+    dataset_split_input.input(gate_validate_btn,
                          inputs=[model_id_input, dataset_id_input, dataset_config_input, dataset_split_input],
-                         outputs=[run_btn, loading_row, preview_row, example_input, example_labels, id2label_mapping_dataframe])
+                         outputs=[run_btn, loading_row, preview_row, example_input, example_labels, id2label_mapping_dataframe, feature_mapping_dataframe])
     id2label_mapping_dataframe.input(gate_validate_btn,
-                         inputs=[model_id_input, dataset_id_input, dataset_config_input, dataset_split_input, id2label_mapping_dataframe],
-                         outputs=[run_btn, loading_row, preview_row, example_input, example_labels, id2label_mapping_dataframe])
-
+                         inputs=[model_id_input, dataset_id_input, dataset_config_input, dataset_split_input, id2label_mapping_dataframe, feature_mapping_dataframe],
+                         outputs=[run_btn, loading_row, preview_row, example_input, example_labels, id2label_mapping_dataframe, feature_mapping_dataframe])
+    feature_mapping_dataframe.input(gate_validate_btn,
+                         inputs=[model_id_input, dataset_id_input, dataset_config_input, dataset_split_input, id2label_mapping_dataframe, feature_mapping_dataframe],
+                         outputs=[run_btn, loading_row, preview_row, example_input, example_labels, id2label_mapping_dataframe, feature_mapping_dataframe])
+    scanners.change(write_scanners, inputs=scanners)
+    run_inference.change(
+        write_inference_type,
+        inputs=[run_inference]
+    )
+
     run_btn.click(
         try_submit,
         inputs=[
@@ -332,6 +359,7 @@ with gr.Blocks(theme=theme) as iface:
            dataset_config_input,
            dataset_split_input,
            id2label_mapping_dataframe,
+           feature_mapping_dataframe,
            run_local,
         ],
         outputs=[
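
Reviewer note: the reshaped `try_submit` above turns the two Gradio dataframes into the `--label_mapping` and `--feature_mapping` CLI arguments. A minimal standalone sketch of that conversion (the column names are the ones this PR introduces; the row values are made up):

```python
# Standalone sketch of try_submit's mapping construction; data is illustrative.
import json
import pandas as pd

id2label_mapping_dataframe = pd.DataFrame({
    "Dataset Labels": ["neg", "pos"],
    "Model Prediction Labels": ["LABEL_0", "LABEL_1"],
})
feature_mapping_dataframe = pd.DataFrame({
    "Dataset Features": ["sentence"],
    "Model Input Features": ["text"],
})

# Same loops as in try_submit, written as dict comprehensions
label_mapping = {str(i): label
                 for i, label in id2label_mapping_dataframe["Model Prediction Labels"].items()}
feature_mapping = {feature_mapping_dataframe["Model Input Features"][i]: feature
                   for i, feature in feature_mapping_dataframe["Dataset Features"].items()}

print(json.dumps(label_mapping))    # {"0": "LABEL_0", "1": "LABEL_1"}
print(json.dumps(feature_mapping))  # {"text": "sentence"}
```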
config.yaml ADDED
@@ -0,0 +1,9 @@
+detectors:
+  - ethical_bias
+  - text_perturbation
+  - robustness
+  - performance
+  - underconfidence
+  - overconfidence
+  - spurious_correlation
+inference_type: hf_pipeline
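
These are the two keys the new utils.py helpers read and write: `detectors` feeds the scanner checkbox group and `inference_type` seeds the "Run with Inference API" checkbox. A quick parse check, assuming the file above is saved as ./config.yaml:

```python
import yaml

# Load the config added in this PR and inspect the two keys utils.py consumes.
with open("./config.yaml", "r") as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

print(config["detectors"])       # ['ethical_bias', 'text_perturbation', ...]
print(config["inference_type"])  # 'hf_pipeline'
```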
text_classification.py CHANGED
@@ -19,9 +19,8 @@ def text_classification_map_model_and_dataset_labels(id2label, dataset_features)
                continue
            if len(feature.names) != len(id2label_mapping.keys()):
                continue
 
            dataset_labels = feature.names
-
            # Try to match labels
            for label in feature.names:
                if label in id2label_mapping.keys():
@@ -31,10 +30,23 @@ def text_classification_map_model_and_dataset_labels(id2label, dataset_features)
                model_label, label = text_classificaiton_match_label_case_unsensative(id2label_mapping, label)
                if model_label is not None:
                    id2label_mapping[model_label] = label
+                else:
+                    print(f"Label {label} is not found in model labels")
 
    return id2label_mapping, dataset_labels
 
-
+'''
+params:
+    column_mapping: dict
+        example: {
+            "text": "sentences",
+            "label": {
+                "label0": "LABEL_0",
+                "label1": "LABEL_1"
+            }
+        }
+    ppl: pipeline
+'''
 def check_column_mapping_keys_validity(column_mapping, ppl):
    # get the element in all the list elements
    column_mapping = json.loads(column_mapping)
@@ -48,19 +60,10 @@ def check_column_mapping_keys_validity(column_mapping, ppl):
 
    return user_labels == model_labels == original_labels
 
-
-def text_classification_fix_column_mapping(column_mapping, ppl, d_id, config, split):
-    # We assume dataset is ok here
-    ds = datasets.load_dataset(d_id, config)[split]
-
-    try:
-        dataset_features = ds.features
-    except AttributeError:
-        # Dataset does not have features, need to provide everything
-        return None, None, None
-
+def infer_text_input_column(column_mapping, dataset_features):
    # Check whether we need to infer the text input column
    infer_text_input_column = True
+    feature_map_df = None
    if "text" in column_mapping.keys():
        dataset_text_column = column_mapping["text"]
        if dataset_text_column in dataset_features.keys():
@@ -71,12 +74,26 @@ def text_classification_fix_column_mapping(column_mapping, ppl, d_id, config, sp
    if infer_text_input_column:
        # Try to retrieve one
        candidates = [f for f in dataset_features if dataset_features[f].dtype == "string"]
+        feature_map_df = pd.DataFrame({
+            "Dataset Features": [candidates[0]],
+            "Model Input Features": ["text"]
+        })
        if len(candidates) > 0:
            logging.debug(f"Candidates are {candidates}")
            column_mapping["text"] = candidates[0]
-        else:
-            # Not found a text feature
-            return column_mapping, None, None
+
+    return column_mapping, feature_map_df
+
+def text_classification_fix_column_mapping(column_mapping, ppl, d_id, config, split):
+    # We assume dataset is ok here
+    ds = datasets.load_dataset(d_id, config)[split]
+    try:
+        dataset_features = ds.features
+    except AttributeError:
+        # Dataset does not have features, need to provide everything
+        return None, None, None, None, None
+
+    column_mapping, feature_map_df = infer_text_input_column(column_mapping, dataset_features)
 
    # Load dataset as DataFrame
    df = ds.to_pandas()
@@ -85,24 +102,13 @@ def text_classification_fix_column_mapping(column_mapping, ppl, d_id, config, sp
    id2label_mapping = {}
    id2label = ppl.model.config.id2label
    label2id = {v: k for k, v in id2label.items()}
-    prediction_input = None
-    prediction_result = None
-    try:
-        # Use the first item to test prediction
-        prediction_input = df.head(1).at[0, column_mapping["text"]]
-        results = ppl({"text": prediction_input}, top_k=None)
-        prediction_result = {
-            f'{result["label"]}({label2id[result["label"]]})': result["score"] for result in results
-        }
-    except Exception:
-        # Pipeline prediction failed, need to provide labels
-        return column_mapping, None, None
 
    # Infer labels
    id2label_mapping, dataset_labels = text_classification_map_model_and_dataset_labels(id2label, dataset_features)
    id2label_mapping_dataset_model = {
        v: k for k, v in id2label_mapping.items()
    }
+
    if "data" in column_mapping.keys():
        if isinstance(column_mapping["data"], list):
            # Use the column mapping passed by user
@@ -112,15 +118,31 @@ def text_classification_fix_column_mapping(column_mapping, ppl, d_id, config, sp
            column_mapping["label"] = {
                i: None for i in id2label.keys()
            }
-            return column_mapping, prediction_result, None
+            return column_mapping, None, None, None, feature_map_df
 
-    prediction_result = {
-        f'[{label2id[result["label"]]}]{result["label"]}(original) - {id2label_mapping[result["label"]]}(mapped)': result["score"] for result in results
-    }
    id2label_df = pd.DataFrame({
        "Dataset Labels": dataset_labels,
        "Model Prediction Labels": [id2label_mapping_dataset_model[label] for label in dataset_labels],
    })
+
+    # Get a sample prediction from the model on the dataset
+    prediction_input = None
+    prediction_result = None
+    try:
+        # Use the first item to test prediction
+        prediction_input = df.head(1).at[0, column_mapping["text"]]
+        results = ppl({"text": prediction_input}, top_k=None)
+        prediction_result = {
+            f'{result["label"]}({label2id[result["label"]]})': result["score"] for result in results
+        }
+    except Exception as e:
+        # Pipeline prediction failed, need to provide labels
+        print(e, '>>>> error')
+        return column_mapping, prediction_input, None, id2label_df, feature_map_df
+
+    prediction_result = {
+        f'[{label2id[result["label"]]}]{result["label"]}(original) - {id2label_mapping[result["label"]]}(mapped)': result["score"] for result in results
+    }
 
    if "data" not in column_mapping.keys():
        # Column mapping should contain original model labels
@@ -128,4 +150,4 @@ def text_classification_fix_column_mapping(column_mapping, ppl, d_id, config, sp
            str(i): id2label_mapping_dataset_model[label] for i, label in zip(id2label.keys(), dataset_labels)
        }
 
-    return column_mapping, prediction_input, prediction_result, id2label_df
+    return column_mapping, prediction_input, prediction_result, id2label_df, feature_map_df
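
For reviewers tracing the refactor: `text_classification_fix_column_mapping` now returns a 5-tuple ending in `feature_df`, and every early return was updated to match. A hedged sketch of the new contract; the model and dataset below are stand-ins, not something verified against this branch:

```python
# Illustrative caller for the new 5-tuple return; IDs are examples only.
from transformers import pipeline
from text_classification import text_classification_fix_column_mapping

ppl = pipeline("text-classification",
               model="distilbert-base-uncased-finetuned-sst-2-english")
column_mapping, sample_input, sample_prediction, id2label_df, feature_df = \
    text_classification_fix_column_mapping({}, ppl, "sst2", "default", "validation")

# feature_df is None when the caller already supplied a valid "text" mapping;
# otherwise it is a one-row DataFrame pairing the inferred dataset column
# with the model's "text" input.
print(id2label_df)
print(feature_df)
```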
utils.py ADDED
@@ -0,0 +1,54 @@
+import yaml
+
+YAML_PATH = "./config.yaml"
+
+class Dumper(yaml.Dumper):
+    def increase_indent(self, flow=False, *args, **kwargs):
+        return super().increase_indent(flow=flow, indentless=False)
+
+# read scanners from yaml file
+# return a list of scanners
+def read_scanners(path):
+    scanners = []
+    with open(path, "r") as f:
+        config = yaml.load(f, Loader=yaml.FullLoader)
+        scanners = config.get("detectors", None)
+    return scanners
+
+# write a list of scanners to yaml file
+def write_scanners(scanners):
+    with open(YAML_PATH, "r") as f:
+        config = yaml.load(f, Loader=yaml.FullLoader)
+
+    config["detectors"] = scanners
+    with open(YAML_PATH, "w") as f:
+        # save scanners to detectors in yaml
+        yaml.dump(config, f, Dumper=Dumper)
+
+# read inference_type from yaml file
+def read_inference_type(path):
+    inference_type = ""
+    with open(path, "r") as f:
+        config = yaml.load(f, Loader=yaml.FullLoader)
+        inference_type = config.get("inference_type", None)
+    return inference_type
+
+# write inference_type to yaml file
+def write_inference_type(use_inference):
+    with open(YAML_PATH, "r") as f:
+        config = yaml.load(f, Loader=yaml.FullLoader)
+    if use_inference:
+        config["inference_type"] = 'hf_inference_api'
+    else:
+        config["inference_type"] = 'hf_pipeline'
+    with open(YAML_PATH, "w") as f:
+        # save inference_type to inference_type in yaml
+        yaml.dump(config, f, Dumper=Dumper)
+
+# convert column mapping dataframe to json
+def convert_column_mapping_to_json(df, label=""):
+    column_mapping = {}
+    column_mapping[label] = []
+    for _, row in df.iterrows():
+        column_mapping[label].append(row.tolist())
+    return column_mapping
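
Usage note: `write_scanners` and `write_inference_type` rewrite ./config.yaml in place, which is how the `scanners.change(...)` and `run_inference.change(...)` handlers in app.py persist UI state. A small round trip, assuming it runs from the Space root so ./config.yaml resolves:

```python
from utils import read_scanners, write_scanners, read_inference_type, write_inference_type

# Persist a narrowed scanner list, then flip the inference type and back.
write_scanners(["robustness", "performance"])
assert read_scanners("./config.yaml") == ["robustness", "performance"]

write_inference_type(True)   # sets inference_type: hf_inference_api
assert read_inference_type("./config.yaml") == "hf_inference_api"
write_inference_type(False)  # restores inference_type: hf_pipeline
```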