inoki-giskard ZeroCommand commited on
Commit
5b8d6d5
1 Parent(s): 134469e

GSK-2509 fix not standard label columns (go_emotions) (#29)

Browse files

- add conditions to extract labels from dataset (d81d6fd28b4998c5d7a4ae2bb505594f3cc9dfbd)
- Merge branch 'main' into pr/29 (bedf925c2a667142d0a1c7250987c84e0ad615d3)
- Fix for flattened raw config (21e0bb3cab9bc33333e7495856224aaff1f571fa)
- move inference api parameters (fc7c452cd6a03999b2cd703dd4ac986f1521e5da)
- add predict button (44ab78abaa1f08d260528e29369f023f6188e9cc)


Co-authored-by: zcy <ZeroCommand@users.noreply.huggingface.co>

app.py CHANGED
@@ -10,7 +10,7 @@ from run_jobs import start_process_run_job, stop_thread
10
  try:
11
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="green")) as demo:
12
  with gr.Tab("Text Classification"):
13
- get_demo_text_classification(demo)
14
  with gr.Tab("Leaderboard"):
15
  get_demo_leaderboard()
16
  with gr.Tab("Logs(Debug)"):
 
10
  try:
11
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="green")) as demo:
12
  with gr.Tab("Text Classification"):
13
+ get_demo_text_classification()
14
  with gr.Tab("Leaderboard"):
15
  get_demo_leaderboard()
16
  with gr.Tab("Logs(Debug)"):
app_text_classification.py CHANGED
@@ -2,17 +2,17 @@ import uuid
2
 
3
  import gradio as gr
4
 
5
- from io_utils import (get_logs_file, read_inference_type, read_scanners,
6
- write_inference_type, write_scanners)
7
  from text_classification_ui_helpers import (check_dataset_and_get_config,
8
  check_dataset_and_get_split,
9
- check_model_and_show_prediction,
10
  deselect_run_inference,
11
  select_run_mode, try_submit,
12
- write_column_mapping_to_config)
 
13
  from wordings import CONFIRM_MAPPING_DETAILS_MD, INTRODUCTION_MD
14
 
15
- MAX_LABELS = 20
16
  MAX_FEATURES = 20
17
 
18
  EXAMPLE_MODEL_ID = "cardiffnlp/twitter-roberta-base-sentiment-latest"
@@ -20,7 +20,7 @@ EXAMPLE_DATA_ID = "tweet_eval"
20
  CONFIG_PATH = "./config.yaml"
21
 
22
 
23
- def get_demo(demo):
24
  with gr.Row():
25
  gr.Markdown(INTRODUCTION_MD)
26
  uid_label = gr.Textbox(
@@ -41,6 +41,13 @@ def get_demo(demo):
41
  dataset_config_input = gr.Dropdown(label="Dataset Config", visible=False)
42
  dataset_split_input = gr.Dropdown(label="Dataset Split", visible=False)
43
 
 
 
 
 
 
 
 
44
  with gr.Row():
45
  example_input = gr.HTML(visible=False)
46
  with gr.Row():
@@ -55,23 +62,17 @@ def get_demo(demo):
55
  column_mappings = []
56
  with gr.Row():
57
  with gr.Column():
 
58
  for _ in range(MAX_LABELS):
59
  column_mappings.append(gr.Dropdown(visible=False))
60
  with gr.Column():
 
61
  for _ in range(MAX_LABELS, MAX_LABELS + MAX_FEATURES):
62
  column_mappings.append(gr.Dropdown(visible=False))
63
 
64
  with gr.Accordion(label="Model Wrap Advance Config (optional)", open=False):
65
  run_local = gr.Checkbox(value=True, label="Run in this Space")
66
- run_inference = gr.Checkbox(value="False", label="Run with Inference API")
67
-
68
- @gr.on(triggers=[uid_label.change], inputs=[uid_label], outputs=[run_inference])
69
- def get_run_mode(uid):
70
- return gr.update(
71
- value=read_inference_type(uid) == "hf_inference_api"
72
- and not run_local.value
73
- )
74
-
75
  inference_token = gr.Textbox(
76
  value="",
77
  label="HF Token for Inference API",
@@ -97,7 +98,7 @@ def get_demo(demo):
97
  run_btn = gr.Button(
98
  "Get Evaluation Result",
99
  variant="primary",
100
- interactive=True,
101
  size="lg",
102
  )
103
 
@@ -120,7 +121,7 @@ def get_demo(demo):
120
 
121
  run_inference.change(
122
  select_run_mode,
123
- inputs=[run_inference, inference_token, uid_label],
124
  outputs=[inference_token, run_local],
125
  )
126
 
@@ -130,17 +131,10 @@ def get_demo(demo):
130
  outputs=[inference_token, run_inference],
131
  )
132
 
133
- inference_token.change(
134
- write_inference_type, inputs=[run_inference, inference_token, uid_label]
135
- )
136
-
137
  gr.on(
138
  triggers=[label.change for label in column_mappings],
139
  fn=write_column_mapping_to_config,
140
  inputs=[
141
- dataset_id_input,
142
- dataset_config_input,
143
- dataset_split_input,
144
  uid_label,
145
  *column_mappings,
146
  ],
@@ -151,9 +145,6 @@ def get_demo(demo):
151
  triggers=[label.input for label in column_mappings],
152
  fn=write_column_mapping_to_config,
153
  inputs=[
154
- dataset_id_input,
155
- dataset_config_input,
156
- dataset_split_input,
157
  uid_label,
158
  *column_mappings,
159
  ],
@@ -164,19 +155,33 @@ def get_demo(demo):
164
  model_id_input.change,
165
  dataset_id_input.change,
166
  dataset_config_input.change,
167
- dataset_split_input.change,
 
 
 
 
 
 
 
 
 
 
 
 
168
  ],
169
- fn=check_model_and_show_prediction,
170
  inputs=[
171
  model_id_input,
172
  dataset_id_input,
173
  dataset_config_input,
174
  dataset_split_input,
 
175
  ],
176
  outputs=[
177
  example_input,
178
  example_prediction,
179
  column_mapping_accordion,
 
180
  *column_mappings,
181
  ],
182
  )
@@ -192,6 +197,8 @@ def get_demo(demo):
192
  dataset_config_input,
193
  dataset_split_input,
194
  run_local,
 
 
195
  uid_label,
196
  ],
197
  outputs=[run_btn, logs],
@@ -202,12 +209,10 @@ def get_demo(demo):
202
 
203
  gr.on(
204
  triggers=[
205
- model_id_input.change,
206
- dataset_config_input.change,
207
- dataset_split_input.change,
208
- run_inference.change,
209
- run_local.change,
210
- scanners.change,
211
  ],
212
  fn=enable_run_btn,
213
  inputs=None,
@@ -215,8 +220,8 @@ def get_demo(demo):
215
  )
216
 
217
  gr.on(
218
- triggers=[label.change for label in column_mappings],
219
  fn=enable_run_btn,
220
- inputs=None,
221
  outputs=[run_btn],
222
  )
 
2
 
3
  import gradio as gr
4
 
5
+ from io_utils import (get_logs_file, read_scanners, write_scanners)
 
6
  from text_classification_ui_helpers import (check_dataset_and_get_config,
7
  check_dataset_and_get_split,
8
+ align_columns_and_show_prediction,
9
  deselect_run_inference,
10
  select_run_mode, try_submit,
11
+ write_column_mapping_to_config,
12
+ precheck_model_ds_enable_example_btn)
13
  from wordings import CONFIRM_MAPPING_DETAILS_MD, INTRODUCTION_MD
14
 
15
+ MAX_LABELS = 40
16
  MAX_FEATURES = 20
17
 
18
  EXAMPLE_MODEL_ID = "cardiffnlp/twitter-roberta-base-sentiment-latest"
 
20
  CONFIG_PATH = "./config.yaml"
21
 
22
 
23
+ def get_demo():
24
  with gr.Row():
25
  gr.Markdown(INTRODUCTION_MD)
26
  uid_label = gr.Textbox(
 
41
  dataset_config_input = gr.Dropdown(label="Dataset Config", visible=False)
42
  dataset_split_input = gr.Dropdown(label="Dataset Split", visible=False)
43
 
44
+ with gr.Row():
45
+ example_btn = gr.Button(
46
+ "Auto-align Columns & Get Sample Prediction",
47
+ visible=True,
48
+ variant="primary",
49
+ interactive=False)
50
+
51
  with gr.Row():
52
  example_input = gr.HTML(visible=False)
53
  with gr.Row():
 
62
  column_mappings = []
63
  with gr.Row():
64
  with gr.Column():
65
+ gr.Markdown("# Label Mapping")
66
  for _ in range(MAX_LABELS):
67
  column_mappings.append(gr.Dropdown(visible=False))
68
  with gr.Column():
69
+ gr.Markdown("# Feature Mapping")
70
  for _ in range(MAX_LABELS, MAX_LABELS + MAX_FEATURES):
71
  column_mappings.append(gr.Dropdown(visible=False))
72
 
73
  with gr.Accordion(label="Model Wrap Advance Config (optional)", open=False):
74
  run_local = gr.Checkbox(value=True, label="Run in this Space")
75
+ run_inference = gr.Checkbox(value=False, label="Run with Inference API")
 
 
 
 
 
 
 
 
76
  inference_token = gr.Textbox(
77
  value="",
78
  label="HF Token for Inference API",
 
98
  run_btn = gr.Button(
99
  "Get Evaluation Result",
100
  variant="primary",
101
+ interactive=False,
102
  size="lg",
103
  )
104
 
 
121
 
122
  run_inference.change(
123
  select_run_mode,
124
+ inputs=[run_inference],
125
  outputs=[inference_token, run_local],
126
  )
127
 
 
131
  outputs=[inference_token, run_inference],
132
  )
133
 
 
 
 
 
134
  gr.on(
135
  triggers=[label.change for label in column_mappings],
136
  fn=write_column_mapping_to_config,
137
  inputs=[
 
 
 
138
  uid_label,
139
  *column_mappings,
140
  ],
 
145
  triggers=[label.input for label in column_mappings],
146
  fn=write_column_mapping_to_config,
147
  inputs=[
 
 
 
148
  uid_label,
149
  *column_mappings,
150
  ],
 
155
  model_id_input.change,
156
  dataset_id_input.change,
157
  dataset_config_input.change,
158
+ dataset_split_input.change],
159
+ fn=precheck_model_ds_enable_example_btn,
160
+ inputs=[
161
+ model_id_input,
162
+ dataset_id_input,
163
+ dataset_config_input,
164
+ dataset_split_input,
165
+ ],
166
+ outputs=[example_btn])
167
+
168
+ gr.on(
169
+ triggers=[
170
+ example_btn.click,
171
  ],
172
+ fn=align_columns_and_show_prediction,
173
  inputs=[
174
  model_id_input,
175
  dataset_id_input,
176
  dataset_config_input,
177
  dataset_split_input,
178
+ uid_label,
179
  ],
180
  outputs=[
181
  example_input,
182
  example_prediction,
183
  column_mapping_accordion,
184
+ run_btn,
185
  *column_mappings,
186
  ],
187
  )
 
197
  dataset_config_input,
198
  dataset_split_input,
199
  run_local,
200
+ run_inference,
201
+ inference_token,
202
  uid_label,
203
  ],
204
  outputs=[run_btn, logs],
 
209
 
210
  gr.on(
211
  triggers=[
212
+ run_inference.input,
213
+ run_local.input,
214
+ inference_token.input,
215
+ scanners.input,
 
 
216
  ],
217
  fn=enable_run_btn,
218
  inputs=None,
 
220
  )
221
 
222
  gr.on(
223
+ triggers=[label.input for label in column_mappings],
224
  fn=enable_run_btn,
225
+ inputs=column_mappings,
226
  outputs=[run_btn],
227
  )
io_utils.py CHANGED
@@ -76,7 +76,6 @@ def read_column_mapping(uid):
76
  config = yaml.load(f, Loader=yaml.FullLoader)
77
  if config:
78
  column_mapping = config.get("column_mapping", dict())
79
- f.close()
80
  return column_mapping
81
 
82
 
@@ -84,7 +83,6 @@ def read_column_mapping(uid):
84
  def write_column_mapping(mapping, uid):
85
  with open(get_yaml_path(uid), "r") as f:
86
  config = yaml.load(f, Loader=yaml.FullLoader)
87
- f.close()
88
 
89
  if config is None:
90
  return
@@ -92,10 +90,9 @@ def write_column_mapping(mapping, uid):
92
  del config["column_mapping"]
93
  else:
94
  config["column_mapping"] = mapping
95
-
96
  with open(get_yaml_path(uid), "w") as f:
97
- yaml.dump(config, f, Dumper=Dumper)
98
- f.close()
99
 
100
 
101
  # convert column mapping dataframe to json
 
76
  config = yaml.load(f, Loader=yaml.FullLoader)
77
  if config:
78
  column_mapping = config.get("column_mapping", dict())
 
79
  return column_mapping
80
 
81
 
 
83
  def write_column_mapping(mapping, uid):
84
  with open(get_yaml_path(uid), "r") as f:
85
  config = yaml.load(f, Loader=yaml.FullLoader)
 
86
 
87
  if config is None:
88
  return
 
90
  del config["column_mapping"]
91
  else:
92
  config["column_mapping"] = mapping
 
93
  with open(get_yaml_path(uid), "w") as f:
94
+ # yaml Dumper will by default sort the keys
95
+ yaml.dump(config, f, Dumper=Dumper, sort_keys=False)
96
 
97
 
98
  # convert column mapping dataframe to json
text_classification.py CHANGED
@@ -15,8 +15,17 @@ def get_labels_and_features_from_dataset(dataset_id, dataset_config, split):
15
  try:
16
  ds = datasets.load_dataset(dataset_id, dataset_config)[split]
17
  dataset_features = ds.features
18
- labels = dataset_features["label"].names
19
- features = [f for f in dataset_features.keys() if f != "label"]
 
 
 
 
 
 
 
 
 
20
  return labels, features
21
  except Exception as e:
22
  logging.warning(
 
15
  try:
16
  ds = datasets.load_dataset(dataset_id, dataset_config)[split]
17
  dataset_features = ds.features
18
+ label_keys = [i for i in dataset_features.keys() if i.startswith('label')]
19
+ if len(label_keys) == 0: # no labels found
20
+ # return everything for post processing
21
+ return list(dataset_features.keys()), list(dataset_features.keys())
22
+ if not isinstance(dataset_features[label_keys[0]], datasets.ClassLabel):
23
+ if hasattr(dataset_features[label_keys[0]], 'feature'):
24
+ label_feat = dataset_features[label_keys[0]].feature
25
+ labels = label_feat.names
26
+ else:
27
+ labels = [dataset_features[label_keys[0]].names]
28
+ features = [f for f in dataset_features.keys() if not f.startswith("label")]
29
  return labels, features
30
  except Exception as e:
31
  logging.warning(
text_classification_ui_helpers.py CHANGED
@@ -10,7 +10,7 @@ from transformers.pipelines import TextClassificationPipeline
10
  from wordings import get_styled_input
11
 
12
  from io_utils import (get_yaml_path, read_column_mapping, save_job_to_pipe,
13
- write_column_mapping, write_inference_type,
14
  write_log_to_user_file)
15
  from text_classification import (check_model, get_example_prediction,
16
  get_labels_and_features_from_dataset)
@@ -18,7 +18,7 @@ from wordings import (CHECK_CONFIG_OR_SPLIT_RAW,
18
  CONFIRM_MAPPING_DETAILS_FAIL_RAW,
19
  MAPPING_STYLED_ERROR_WARNING)
20
 
21
- MAX_LABELS = 20
22
  MAX_FEATURES = 20
23
 
24
  HF_REPO_ID = "HF_REPO_ID"
@@ -51,10 +51,8 @@ def check_dataset_and_get_split(dataset_id, dataset_config):
51
  pass
52
 
53
 
54
- def select_run_mode(run_inf, inf_token, uid):
55
  if run_inf:
56
- if len(inf_token) > 0:
57
- write_inference_type(run_inf, inf_token, uid)
58
  return (gr.update(visible=True), gr.update(value=False))
59
  else:
60
  return (gr.update(visible=False), gr.update(value=True))
@@ -68,46 +66,62 @@ def deselect_run_inference(run_local):
68
 
69
 
70
  def write_column_mapping_to_config(
71
- dataset_id, dataset_config, dataset_split, uid, *labels
72
  ):
73
  # TODO: Substitute 'text' with more features for zero-shot
74
  # we are not using ds features because we only support "text" for now
75
- ds_labels, _ = get_labels_and_features_from_dataset(
76
- dataset_id, dataset_config, dataset_split
77
- )
78
  if labels is None:
79
  return
 
 
80
 
81
- all_mappings = dict()
82
-
83
- if "labels" not in all_mappings.keys():
84
- all_mappings["labels"] = dict()
85
- for i, label in enumerate(labels[:MAX_LABELS]):
86
- if label:
87
- all_mappings["labels"][label] = ds_labels[i % len(ds_labels)]
88
- if "features" not in all_mappings.keys():
89
- all_mappings["features"] = dict()
90
- for _, feat in enumerate(labels[MAX_LABELS : (MAX_LABELS + MAX_FEATURES)]):
91
- if feat:
92
- # TODO: Substitute 'text' with more features for zero-shot
93
- all_mappings["features"]["text"] = feat
94
  write_column_mapping(all_mappings, uid)
95
 
96
-
97
- def list_labels_and_features_from_dataset(ds_labels, ds_features, model_id2label):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  model_labels = list(model_id2label.values())
99
- len_model_labels = len(model_labels)
 
 
 
 
 
 
 
 
 
 
 
 
100
  lables = [
101
  gr.Dropdown(
102
  label=f"{label}",
103
  choices=model_labels,
104
- value=model_id2label[i % len_model_labels],
105
  interactive=True,
106
  visible=True,
107
  )
108
- for i, label in enumerate(ds_labels[:MAX_LABELS])
109
  ]
110
  lables += [gr.Dropdown(visible=False) for _ in range(MAX_LABELS - len(lables))]
 
 
111
  # TODO: Substitute 'text' with more features for zero-shot
112
  features = [
113
  gr.Dropdown(
@@ -122,11 +136,27 @@ def list_labels_and_features_from_dataset(ds_labels, ds_features, model_id2label
122
  features += [
123
  gr.Dropdown(visible=False) for _ in range(MAX_FEATURES - len(features))
124
  ]
 
 
 
125
  return lables + features
126
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
- def check_model_and_show_prediction(
129
- model_id, dataset_id, dataset_config, dataset_split
130
  ):
131
  ppl = check_model(model_id)
132
  if ppl is None or not isinstance(ppl, TextClassificationPipeline):
@@ -134,6 +164,8 @@ def check_model_and_show_prediction(
134
  return (
135
  gr.update(visible=False),
136
  gr.update(visible=False),
 
 
137
  *[gr.update(visible=False) for _ in range(MAX_LABELS + MAX_FEATURES)],
138
  )
139
 
@@ -147,6 +179,7 @@ def check_model_and_show_prediction(
147
  gr.update(visible=False),
148
  gr.update(visible=False),
149
  gr.update(visible=False, open=False),
 
150
  *dropdown_placement,
151
  )
152
  model_id2label = ppl.model.config.id2label
@@ -161,6 +194,7 @@ def check_model_and_show_prediction(
161
  gr.update(visible=False),
162
  gr.update(visible=False),
163
  gr.update(visible=False, open=False),
 
164
  *dropdown_placement,
165
  )
166
 
@@ -168,6 +202,7 @@ def check_model_and_show_prediction(
168
  ds_labels,
169
  ds_features,
170
  model_id2label,
 
171
  )
172
 
173
  # when labels or features are not aligned
@@ -180,6 +215,7 @@ def check_model_and_show_prediction(
180
  gr.update(value=MAPPING_STYLED_ERROR_WARNING, visible=True),
181
  gr.update(visible=False),
182
  gr.update(visible=True, open=True),
 
183
  *column_mappings,
184
  )
185
 
@@ -190,13 +226,11 @@ def check_model_and_show_prediction(
190
  gr.update(value=get_styled_input(prediction_input), visible=True),
191
  gr.update(value=prediction_output, visible=True),
192
  gr.update(visible=True, open=False),
 
193
  *column_mappings,
194
  )
195
 
196
-
197
- def try_submit(m_id, d_id, config, split, local, uid):
198
- all_mappings = read_column_mapping(uid)
199
-
200
  if all_mappings is None:
201
  gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
202
  return (gr.update(interactive=True), gr.update(visible=False))
@@ -204,6 +238,8 @@ def try_submit(m_id, d_id, config, split, local, uid):
204
  if "labels" not in all_mappings.keys():
205
  gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
206
  return (gr.update(interactive=True), gr.update(visible=False))
 
 
207
  label_mapping = {}
208
  for i, label in zip(
209
  range(len(all_mappings["labels"].keys())), all_mappings["labels"].keys()
@@ -214,73 +250,88 @@ def try_submit(m_id, d_id, config, split, local, uid):
214
  gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
215
  return (gr.update(interactive=True), gr.update(visible=False))
216
  feature_mapping = all_mappings["features"]
 
 
 
 
 
 
217
 
218
  leaderboard_dataset = None
219
  if os.environ.get("SPACE_ID") == "giskardai/giskard-evaluator":
220
  leaderboard_dataset = "ZeroCommand/test-giskard-report"
 
 
 
 
 
221
 
222
  # TODO: Set column mapping for some dataset such as `amazon_polarity`
223
- if local:
224
- command = [
225
- "giskard_scanner",
226
- "--loader",
227
- "huggingface",
228
- "--model",
229
- m_id,
230
- "--dataset",
231
- d_id,
232
- "--dataset_config",
233
- config,
234
- "--dataset_split",
235
- split,
236
- "--hf_token",
237
- os.environ.get(HF_WRITE_TOKEN),
238
- "--discussion_repo",
239
- os.environ.get(HF_REPO_ID) or os.environ.get(HF_SPACE_ID),
240
- "--output_format",
241
- "markdown",
242
- "--output_portal",
243
- "huggingface",
244
- "--feature_mapping",
245
- json.dumps(feature_mapping),
246
- "--label_mapping",
247
- json.dumps(label_mapping),
248
- "--scan_config",
249
- get_yaml_path(uid),
250
- "--leaderboard_dataset",
251
- leaderboard_dataset,
252
- ]
253
- if os.environ.get(HF_GSK_HUB_KEY):
254
- command.append("--giskard_hub_api_key")
255
- command.append(os.environ.get(HF_GSK_HUB_KEY))
256
- if os.environ.get(HF_GSK_HUB_URL):
257
- command.append("--giskard_hub_url")
258
- command.append(os.environ.get(HF_GSK_HUB_URL))
259
- if os.environ.get(HF_GSK_HUB_PROJECT_KEY):
260
- command.append("--giskard_hub_project_key")
261
- command.append(os.environ.get(HF_GSK_HUB_PROJECT_KEY))
262
- if os.environ.get(HF_GSK_HUB_HF_TOKEN):
263
- command.append("--giskard_hub_hf_token")
264
- command.append(os.environ.get(HF_GSK_HUB_HF_TOKEN))
265
- if os.environ.get(HF_GSK_HUB_UNLOCK_TOKEN):
266
- command.append("--giskard_hub_unlock_token")
267
- command.append(os.environ.get(HF_GSK_HUB_UNLOCK_TOKEN))
268
-
269
- eval_str = f"[{m_id}]<{d_id}({config}, {split} set)>"
270
- logging.info(f"Start local evaluation on {eval_str}")
271
- save_job_to_pipe(uid, command, eval_str, threading.Lock())
272
- write_log_to_user_file(
273
- uid,
274
- f"Start local evaluation on {eval_str}. Please wait for your job to start...\n",
275
- )
276
- gr.Info(f"Start local evaluation on {eval_str}")
 
 
 
 
277
 
278
- return (
279
- gr.update(interactive=False),
280
- gr.update(lines=5, visible=True, interactive=False),
281
- )
282
 
283
- else:
284
- gr.Info("TODO: Submit task to an endpoint")
285
 
286
- return (gr.update(interactive=True), gr.update(visible=False)) # Submit button
 
 
 
10
  from wordings import get_styled_input
11
 
12
  from io_utils import (get_yaml_path, read_column_mapping, save_job_to_pipe,
13
+ write_column_mapping,
14
  write_log_to_user_file)
15
  from text_classification import (check_model, get_example_prediction,
16
  get_labels_and_features_from_dataset)
 
18
  CONFIRM_MAPPING_DETAILS_FAIL_RAW,
19
  MAPPING_STYLED_ERROR_WARNING)
20
 
21
+ MAX_LABELS = 40
22
  MAX_FEATURES = 20
23
 
24
  HF_REPO_ID = "HF_REPO_ID"
 
51
  pass
52
 
53
 
54
+ def select_run_mode(run_inf):
55
  if run_inf:
 
 
56
  return (gr.update(visible=True), gr.update(value=False))
57
  else:
58
  return (gr.update(visible=False), gr.update(value=True))
 
66
 
67
 
68
  def write_column_mapping_to_config(
69
+ uid, *labels
70
  ):
71
  # TODO: Substitute 'text' with more features for zero-shot
72
  # we are not using ds features because we only support "text" for now
73
+ all_mappings = read_column_mapping(uid)
74
+
 
75
  if labels is None:
76
  return
77
+ all_mappings = export_mappings(all_mappings, "labels", None, labels[:MAX_LABELS])
78
+ all_mappings = export_mappings(all_mappings, "features", ["text"], labels[MAX_LABELS : (MAX_LABELS + MAX_FEATURES)])
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  write_column_mapping(all_mappings, uid)
81
 
82
+ def export_mappings(all_mappings, key, subkeys, values):
83
+ if key not in all_mappings.keys():
84
+ all_mappings[key] = dict()
85
+ if subkeys is None:
86
+ subkeys = list(all_mappings[key].keys())
87
+
88
+ if not subkeys:
89
+ logging.debug(f"subkeys is empty for {key}")
90
+ return all_mappings
91
+
92
+ for i, subkey in enumerate(subkeys):
93
+ if subkey:
94
+ all_mappings[key][subkey] = values[i % len(values)]
95
+ return all_mappings
96
+
97
+ def list_labels_and_features_from_dataset(ds_labels, ds_features, model_id2label, uid):
98
  model_labels = list(model_id2label.values())
99
+ all_mappings = read_column_mapping(uid)
100
+ # For flattened raw datasets with no labels
101
+ # check if there are shared labels between model and dataset
102
+ shared_labels = set(model_labels).intersection(set(ds_labels))
103
+ if shared_labels:
104
+ ds_labels = list(shared_labels)
105
+ if len(ds_labels) > MAX_LABELS:
106
+ ds_labels = ds_labels[:MAX_LABELS]
107
+ gr.Warning(f"The number of labels is truncated to length {MAX_LABELS}")
108
+
109
+ ds_labels.sort()
110
+ model_labels.sort()
111
+
112
  lables = [
113
  gr.Dropdown(
114
  label=f"{label}",
115
  choices=model_labels,
116
+ value=model_id2label[i % len(model_labels)],
117
  interactive=True,
118
  visible=True,
119
  )
120
+ for i, label in enumerate(ds_labels)
121
  ]
122
  lables += [gr.Dropdown(visible=False) for _ in range(MAX_LABELS - len(lables))]
123
+ all_mappings = export_mappings(all_mappings, "labels", ds_labels, model_labels)
124
+
125
  # TODO: Substitute 'text' with more features for zero-shot
126
  features = [
127
  gr.Dropdown(
 
136
  features += [
137
  gr.Dropdown(visible=False) for _ in range(MAX_FEATURES - len(features))
138
  ]
139
+ all_mappings = export_mappings(all_mappings, "features", ["text"], ds_features)
140
+ write_column_mapping(all_mappings, uid)
141
+
142
  return lables + features
143
 
144
+ def precheck_model_ds_enable_example_btn(model_id, dataset_id, dataset_config, dataset_split):
145
+ ppl = check_model(model_id)
146
+ if ppl is None or not isinstance(ppl, TextClassificationPipeline):
147
+ gr.Warning("Please check your model.")
148
+ return gr.update(interactive=False)
149
+ ds_labels, ds_features = get_labels_and_features_from_dataset(
150
+ dataset_id, dataset_config, dataset_split
151
+ )
152
+ if not isinstance(ds_labels, list) or not isinstance(ds_features, list):
153
+ gr.Warning(CHECK_CONFIG_OR_SPLIT_RAW)
154
+ return gr.update(interactive=False)
155
+
156
+ return gr.update(interactive=True)
157
 
158
+ def align_columns_and_show_prediction(
159
+ model_id, dataset_id, dataset_config, dataset_split, uid
160
  ):
161
  ppl = check_model(model_id)
162
  if ppl is None or not isinstance(ppl, TextClassificationPipeline):
 
164
  return (
165
  gr.update(visible=False),
166
  gr.update(visible=False),
167
+ gr.update(visible=False, open=False),
168
+ gr.update(interactive=False),
169
  *[gr.update(visible=False) for _ in range(MAX_LABELS + MAX_FEATURES)],
170
  )
171
 
 
179
  gr.update(visible=False),
180
  gr.update(visible=False),
181
  gr.update(visible=False, open=False),
182
+ gr.update(interactive=False),
183
  *dropdown_placement,
184
  )
185
  model_id2label = ppl.model.config.id2label
 
194
  gr.update(visible=False),
195
  gr.update(visible=False),
196
  gr.update(visible=False, open=False),
197
+ gr.update(interactive=False),
198
  *dropdown_placement,
199
  )
200
 
 
202
  ds_labels,
203
  ds_features,
204
  model_id2label,
205
+ uid,
206
  )
207
 
208
  # when labels or features are not aligned
 
215
  gr.update(value=MAPPING_STYLED_ERROR_WARNING, visible=True),
216
  gr.update(visible=False),
217
  gr.update(visible=True, open=True),
218
+ gr.update(interactive=True),
219
  *column_mappings,
220
  )
221
 
 
226
  gr.update(value=get_styled_input(prediction_input), visible=True),
227
  gr.update(value=prediction_output, visible=True),
228
  gr.update(visible=True, open=False),
229
+ gr.update(interactive=True),
230
  *column_mappings,
231
  )
232
 
233
+ def check_column_mapping_keys_validity(all_mappings):
 
 
 
234
  if all_mappings is None:
235
  gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
236
  return (gr.update(interactive=True), gr.update(visible=False))
 
238
  if "labels" not in all_mappings.keys():
239
  gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
240
  return (gr.update(interactive=True), gr.update(visible=False))
241
+
242
+ def construct_label_and_feature_mapping(all_mappings):
243
  label_mapping = {}
244
  for i, label in zip(
245
  range(len(all_mappings["labels"].keys())), all_mappings["labels"].keys()
 
250
  gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
251
  return (gr.update(interactive=True), gr.update(visible=False))
252
  feature_mapping = all_mappings["features"]
253
+ return label_mapping, feature_mapping
254
+
255
+ def try_submit(m_id, d_id, config, split, local, inference, inference_token, uid):
256
+ all_mappings = read_column_mapping(uid)
257
+ check_column_mapping_keys_validity(all_mappings)
258
+ label_mapping, feature_mapping = construct_label_and_feature_mapping(all_mappings)
259
 
260
  leaderboard_dataset = None
261
  if os.environ.get("SPACE_ID") == "giskardai/giskard-evaluator":
262
  leaderboard_dataset = "ZeroCommand/test-giskard-report"
263
+
264
+ if local:
265
+ inference_type = "hf_pipeline"
266
+ if inference and inference_token:
267
+ inference_type = "hf_inference_api"
268
 
269
  # TODO: Set column mapping for some dataset such as `amazon_polarity`
270
+ command = [
271
+ "giskard_scanner",
272
+ "--loader",
273
+ "huggingface",
274
+ "--model",
275
+ m_id,
276
+ "--dataset",
277
+ d_id,
278
+ "--dataset_config",
279
+ config,
280
+ "--dataset_split",
281
+ split,
282
+ "--hf_token",
283
+ os.environ.get(HF_WRITE_TOKEN),
284
+ "--discussion_repo",
285
+ os.environ.get(HF_REPO_ID) or os.environ.get(HF_SPACE_ID),
286
+ "--output_format",
287
+ "markdown",
288
+ "--output_portal",
289
+ "huggingface",
290
+ "--feature_mapping",
291
+ json.dumps(feature_mapping),
292
+ "--label_mapping",
293
+ json.dumps(label_mapping),
294
+ "--scan_config",
295
+ get_yaml_path(uid),
296
+ "--leaderboard_dataset",
297
+ leaderboard_dataset,
298
+ "--inference_type",
299
+ inference_type,
300
+ "--inference_token",
301
+ inference_token,
302
+ ]
303
+ if os.environ.get(HF_GSK_HUB_KEY):
304
+ command.append("--giskard_hub_api_key")
305
+ command.append(os.environ.get(HF_GSK_HUB_KEY))
306
+ if os.environ.get(HF_GSK_HUB_URL):
307
+ command.append("--giskard_hub_url")
308
+ command.append(os.environ.get(HF_GSK_HUB_URL))
309
+ if os.environ.get(HF_GSK_HUB_PROJECT_KEY):
310
+ command.append("--giskard_hub_project_key")
311
+ command.append(os.environ.get(HF_GSK_HUB_PROJECT_KEY))
312
+ if os.environ.get(HF_GSK_HUB_HF_TOKEN):
313
+ command.append("--giskard_hub_hf_token")
314
+ command.append(os.environ.get(HF_GSK_HUB_HF_TOKEN))
315
+ if os.environ.get(HF_GSK_HUB_UNLOCK_TOKEN):
316
+ command.append("--giskard_hub_unlock_token")
317
+ command.append(os.environ.get(HF_GSK_HUB_UNLOCK_TOKEN))
318
+
319
+ eval_str = f"[{m_id}]<{d_id}({config}, {split} set)>"
320
+ logging.info(f"Start local evaluation on {eval_str}")
321
+ save_job_to_pipe(uid, command, eval_str, threading.Lock())
322
+ print(command)
323
+ write_log_to_user_file(
324
+ uid,
325
+ f"Start local evaluation on {eval_str}. Please wait for your job to start...\n",
326
+ )
327
+ gr.Info(f"Start local evaluation on {eval_str}")
328
 
329
+ return (
330
+ gr.update(interactive=False),
331
+ gr.update(lines=5, visible=True, interactive=False),
332
+ )
333
 
 
 
334
 
335
+ # TODO: Submit task to an endpoint")
336
+
337
+ # return (gr.update(interactive=True), gr.update(visible=False)) # Submit button