ZeroCommand committed on
Commit
461883a
1 Parent(s): e7115eb

remove pipeline and improve events trigger

Browse files
app_text_classification.py CHANGED
@@ -8,11 +8,10 @@ from text_classification_ui_helpers import (
8
  align_columns_and_show_prediction,
9
  check_dataset,
10
  precheck_model_ds_enable_example_btn,
11
- select_run_mode,
12
  try_submit,
13
  write_column_mapping_to_config,
14
  )
15
- from wordings import CONFIRM_MAPPING_DETAILS_MD, INTRODUCTION_MD
16
 
17
  MAX_LABELS = 40
18
  MAX_FEATURES = 20
@@ -80,30 +79,9 @@ def get_demo():
80
  column_mappings.append(gr.Dropdown(visible=False))
81
 
82
  with gr.Accordion(label="Model Wrap Advance Config", open=True):
83
- run_inference = gr.Checkbox(
84
- value=True,
85
- label="Run with HF Inference API"
86
- )
87
- gr.HTML(
88
- value="""
89
- We recommend to use
90
- <a href="https://huggingface.co/docs/api-inference/detailed_parameters#text-classification-task">
91
- Hugging Face Inference API
92
- </a>
93
- for the evaluation,
94
- which requires your <a href="https://huggingface.co/settings/tokens">HF token</a>.
95
- <br/>
96
- Otherwise, an
97
- <a href="https://huggingface.co/docs/transformers/main_classes/pipelines#transformers.TextClassificationPipeline">
98
- HF pipeline
99
- </a>
100
- will be created and run in this Space. It takes more time to get the result.
101
- <br/>
102
- <b>
103
- Do not worry, your HF token is only used in this Space for your evaluation.
104
- </b>
105
- """,
106
- )
107
  inference_token = gr.Textbox(
108
  placeholder="hf-xxxxxxxxxxxxxxxxxxxx",
109
  value="",
@@ -112,7 +90,6 @@ def get_demo():
112
  interactive=True,
113
  )
114
 
115
-
116
  with gr.Accordion(label="Scanner Advance Config (optional)", open=False):
117
  scanners = gr.CheckboxGroup(label="Scan Settings", visible=True)
118
 
@@ -143,37 +120,21 @@ def get_demo():
143
  every=0.5,
144
  )
145
 
146
- dataset_id_input.change(
147
- check_dataset,
148
- inputs=[dataset_id_input],
149
- outputs=[dataset_config_input, dataset_split_input, first_line_ds, loading_status],
150
- )
151
-
152
- dataset_config_input.change(
153
- check_dataset,
154
- inputs=[dataset_id_input, dataset_config_input],
155
- outputs=[dataset_config_input, dataset_split_input, first_line_ds, loading_status],
156
- )
157
-
158
- dataset_split_input.change(
159
- check_dataset,
160
- inputs=[dataset_id_input, dataset_config_input, dataset_split_input],
161
- outputs=[dataset_config_input, dataset_split_input, first_line_ds, loading_status],
162
- )
163
-
164
  scanners.change(write_scanners, inputs=[scanners, uid_label])
165
 
166
- run_inference.change(
167
- select_run_mode,
168
- inputs=[run_inference],
169
- outputs=[inference_token],
170
- )
171
-
172
  gr.on(
173
  triggers=[model_id_input.change],
174
  fn=get_related_datasets_from_leaderboard,
175
  inputs=[model_id_input],
176
  outputs=[dataset_id_input],
 
 
 
 
 
 
 
177
  )
178
 
179
  gr.on(
@@ -209,7 +170,7 @@ def get_demo():
209
  dataset_config_input,
210
  dataset_split_input,
211
  ],
212
- outputs=[example_btn, loading_status],
213
  )
214
 
215
  gr.on(
@@ -254,7 +215,7 @@ def get_demo():
254
  )
255
 
256
  def enable_run_btn(run_inference, inference_token, model_id, dataset_id, dataset_config, dataset_split):
257
- if run_inference and inference_token == "":
258
  return gr.update(interactive=False)
259
  if model_id == "" or dataset_id == "" or dataset_config == "" or dataset_split == "":
260
  return gr.update(interactive=False)
 
8
  align_columns_and_show_prediction,
9
  check_dataset,
10
  precheck_model_ds_enable_example_btn,
 
11
  try_submit,
12
  write_column_mapping_to_config,
13
  )
14
+ from wordings import CONFIRM_MAPPING_DETAILS_MD, INTRODUCTION_MD, USE_INFERENCE_API_TIP
15
 
16
  MAX_LABELS = 40
17
  MAX_FEATURES = 20
 
79
  column_mappings.append(gr.Dropdown(visible=False))
80
 
81
  with gr.Accordion(label="Model Wrap Advance Config", open=True):
82
+ gr.HTML(USE_INFERENCE_API_TIP)
83
+
84
+ run_inference = gr.Checkbox(value=True, label="Run with Inference API")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  inference_token = gr.Textbox(
86
  placeholder="hf-xxxxxxxxxxxxxxxxxxxx",
87
  value="",
 
90
  interactive=True,
91
  )
92
 
 
93
  with gr.Accordion(label="Scanner Advance Config (optional)", open=False):
94
  scanners = gr.CheckboxGroup(label="Scan Settings", visible=True)
95
 
 
120
  every=0.5,
121
  )
122
 
123
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  scanners.change(write_scanners, inputs=[scanners, uid_label])
125
 
 
 
 
 
 
 
126
  gr.on(
127
  triggers=[model_id_input.change],
128
  fn=get_related_datasets_from_leaderboard,
129
  inputs=[model_id_input],
130
  outputs=[dataset_id_input],
131
+ ).then(fn=check_dataset, inputs=[dataset_id_input], outputs=[dataset_config_input, dataset_split_input, loading_status])
132
+
133
+ gr.on(
134
+ triggers=[dataset_id_input.input],
135
+ fn=check_dataset,
136
+ inputs=[dataset_id_input],
137
+ outputs=[dataset_config_input, dataset_split_input, loading_status]
138
  )
139
 
140
  gr.on(
 
170
  dataset_config_input,
171
  dataset_split_input,
172
  ],
173
+ outputs=[example_btn, first_line_ds, loading_status],
174
  )
175
 
176
  gr.on(
 
215
  )
216
 
217
  def enable_run_btn(run_inference, inference_token, model_id, dataset_id, dataset_config, dataset_split):
218
+ if not run_inference or inference_token == "":
219
  return gr.update(interactive=False)
220
  if model_id == "" or dataset_id == "" or dataset_config == "" or dataset_split == "":
221
  return gr.update(interactive=False)
text_classification.py CHANGED
@@ -5,15 +5,13 @@ import datasets
5
  import huggingface_hub
6
  import pandas as pd
7
  from transformers import pipeline
 
 
8
 
 
9
 
10
- def get_labels_and_features_from_dataset(dataset_id, dataset_config, split):
11
- if not dataset_config:
12
- dataset_config = "default"
13
- if not split:
14
- split = "train"
15
  try:
16
- ds = datasets.load_dataset(dataset_id, dataset_config)[split]
17
  dataset_features = ds.features
18
  label_keys = [i for i in dataset_features.keys() if i.startswith('label')]
19
  if len(label_keys) == 0: # no labels found
@@ -29,12 +27,60 @@ def get_labels_and_features_from_dataset(dataset_id, dataset_config, split):
29
  return labels, features
30
  except Exception as e:
31
  logging.warning(
32
- f"Failed to load dataset {dataset_id} with config {dataset_config}: {e}"
33
  )
34
  return None, None
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
- def check_model(model_id):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  try:
39
  task = huggingface_hub.model_info(model_id).pipeline_tag
40
  except Exception:
@@ -207,7 +253,7 @@ def check_dataset_features_validity(d_id, config, split):
207
  return df, dataset_features
208
 
209
 
210
- def get_example_prediction(ppl, dataset_id, dataset_config, dataset_split):
211
  # get a sample prediction from the model on the dataset
212
  prediction_input = None
213
  prediction_result = None
@@ -220,9 +266,13 @@ def get_example_prediction(ppl, dataset_id, dataset_config, dataset_split):
220
  else:
221
  prediction_input = ds[0]["text"]
222
 
223
- print("prediction_input", prediction_input)
224
- results = ppl(prediction_input, top_k=None)
225
- # Display results in original label and mapped label
 
 
 
 
226
  prediction_result = {
227
  f'{result["label"]}': result["score"] for result in results
228
  }
@@ -298,4 +348,4 @@ def text_classification_fix_column_mapping(column_mapping, ppl, d_id, config, sp
298
  prediction_result,
299
  id2label_df,
300
  feature_map_df,
301
- )
 
5
  import huggingface_hub
6
  import pandas as pd
7
  from transformers import pipeline
8
+ import requests
9
+ import os
10
 
11
+ HF_WRITE_TOKEN = "HF_WRITE_TOKEN"
12
 
13
+ def get_labels_and_features_from_dataset(ds):
 
 
 
 
14
  try:
 
15
  dataset_features = ds.features
16
  label_keys = [i for i in dataset_features.keys() if i.startswith('label')]
17
  if len(label_keys) == 0: # no labels found
 
27
  return labels, features
28
  except Exception as e:
29
  logging.warning(
30
+ f"Get Labels/Features Failed for dataset: {e}"
31
  )
32
  return None, None
33
 
34
+ def check_model_task(model_id):
35
+ # check if model is valid on huggingface
36
+ try:
37
+ task = huggingface_hub.model_info(model_id).pipeline_tag
38
+ if task is None:
39
+ return None
40
+ return task
41
+ except Exception:
42
+ return None
43
+
44
+ def get_model_labels(model_id, example_input):
45
+ hf_token = os.environ.get(HF_WRITE_TOKEN, default="")
46
+ payload = {"inputs": example_input, "options": {"use_cache": True}}
47
+ response = hf_inference_api(model_id, hf_token, payload)
48
+ if "error" in response:
49
+ return None
50
+ return extract_from_response(response, "label")
51
+
52
+ def extract_from_response(data, key):
53
+ results = []
54
+
55
+ if isinstance(data, dict):
56
+ res = data.get(key)
57
+ if res is not None:
58
+ results.append(res)
59
 
60
+ for value in data.values():
61
+ results.extend(extract_from_response(value, key))
62
+
63
+ elif isinstance(data, list):
64
+ for element in data:
65
+ results.extend(extract_from_response(element, key))
66
+
67
+ return results
68
+
69
+ def hf_inference_api(model_id, hf_token, payload):
70
+ hf_inference_api_endpoint = os.environ.get(
71
+ "HF_INFERENCE_ENDPOINT", default="https://api-inference.huggingface.co"
72
+ )
73
+ url = f"{hf_inference_api_endpoint}/models/{model_id}"
74
+ headers = {"Authorization": f"Bearer {hf_token}"}
75
+ response = requests.post(url, headers=headers, json=payload)
76
+ if response.status_code != 200:
77
+ logging.ERROR(f"Request to inference API returns {response.status_code}")
78
+ try:
79
+ return response.json()
80
+ except Exception:
81
+ return {"error": response.content}
82
+
83
+ def check_model_pipeline(model_id):
84
  try:
85
  task = huggingface_hub.model_info(model_id).pipeline_tag
86
  except Exception:
 
253
  return df, dataset_features
254
 
255
 
256
+ def get_example_prediction(model_id, dataset_id, dataset_config, dataset_split):
257
  # get a sample prediction from the model on the dataset
258
  prediction_input = None
259
  prediction_result = None
 
266
  else:
267
  prediction_input = ds[0]["text"]
268
 
269
+ hf_token = os.environ.get(HF_WRITE_TOKEN, default="")
270
+ payload = {"inputs": prediction_input, "options": {"use_cache": True}}
271
+ results = hf_inference_api(model_id, hf_token, payload)
272
+ while isinstance(results, list):
273
+ if isinstance(results[0], dict):
274
+ break
275
+ results = results[0]
276
  prediction_result = {
277
  f'{result["label"]}': result["score"] for result in results
278
  }
 
348
  prediction_result,
349
  id2label_df,
350
  feature_map_df,
351
+ )
text_classification_ui_helpers.py CHANGED
@@ -9,7 +9,6 @@ import leaderboard
9
  import datasets
10
  import gradio as gr
11
  import pandas as pd
12
- from transformers.pipelines import TextClassificationPipeline
13
 
14
  from io_utils import (
15
  get_yaml_path,
@@ -19,7 +18,7 @@ from io_utils import (
19
  write_log_to_user_file,
20
  )
21
  from text_classification import (
22
- check_model,
23
  get_example_prediction,
24
  get_labels_and_features_from_dataset,
25
  )
@@ -43,72 +42,55 @@ HF_GSK_HUB_HF_TOKEN = "GSK_HF_TOKEN"
43
  HF_GSK_HUB_UNLOCK_TOKEN = "GSK_HUB_UNLOCK_TOKEN"
44
 
45
  LEADERBOARD = "giskard-bot/evaluator-leaderboard"
 
 
 
 
 
46
  def get_related_datasets_from_leaderboard(model_id):
47
  records = leaderboard.records
48
  model_records = records[records["model_id"] == model_id]
49
- datasets_unique = model_records["dataset_id"].unique()
 
50
  if len(datasets_unique) == 0:
51
  all_unique_datasets = list(records["dataset_id"].unique())
52
- print(type(all_unique_datasets), all_unique_datasets)
53
  return gr.update(choices=all_unique_datasets, value="")
 
54
  return gr.update(choices=datasets_unique, value=datasets_unique[0])
55
 
56
 
57
  logger = logging.getLogger(__file__)
58
 
59
 
60
- def check_dataset(dataset_id, dataset_config=None, dataset_split=None):
61
- configs = ["default"]
62
- splits = ["default"]
63
- logger.info(f"Loading {dataset_id}, {dataset_config}, {dataset_split}")
64
  try:
65
  configs = datasets.get_dataset_config_names(dataset_id)
 
 
 
 
 
 
66
  splits = list(
67
- datasets.load_dataset(
68
- dataset_id, configs[0] if not dataset_config else dataset_config
69
- ).keys()
 
 
 
 
 
70
  )
71
- if dataset_config == None:
72
- dataset_config = configs[0]
73
- dataset_split = splits[0]
74
- elif dataset_split == None:
75
- dataset_split = splits[0]
76
  except Exception as e:
77
- # Dataset may not exist
78
- logger.warn(
79
- f"Failed to load dataset {dataset_id} with config {dataset_config}: {e}"
 
 
80
  )
81
- if dataset_config == None:
82
- return (
83
- gr.Dropdown(configs, value=configs[0], visible=True),
84
- gr.Dropdown(splits, value=splits[0], visible=True),
85
- gr.DataFrame(pd.DataFrame(), visible=False),
86
- "",
87
- )
88
- elif dataset_split == None:
89
- return (
90
- gr.Dropdown(configs, value=dataset_config, visible=True),
91
- gr.Dropdown(splits, value=splits[0], visible=True),
92
- gr.DataFrame(pd.DataFrame(), visible=False),
93
- "",
94
- )
95
-
96
- dataset_dict = datasets.load_dataset(dataset_id, dataset_config)
97
- dataframe: pd.DataFrame = dataset_dict[dataset_split].to_pandas().head(5)
98
- return (
99
- gr.Dropdown(configs, value=dataset_config, visible=True),
100
- gr.Dropdown(splits, value=dataset_split, visible=True),
101
- gr.DataFrame(dataframe, visible=True),
102
- "",
103
- )
104
 
105
 
106
- def select_run_mode(run_inf):
107
- if run_inf:
108
- return gr.update(visible=True)
109
- else:
110
- return gr.update(visible=False)
111
-
112
 
113
  def write_column_mapping_to_config(uid, *labels):
114
  # TODO: Substitute 'text' with more features for zero-shot
@@ -144,8 +126,7 @@ def export_mappings(all_mappings, key, subkeys, values):
144
  return all_mappings
145
 
146
 
147
- def list_labels_and_features_from_dataset(ds_labels, ds_features, model_id2label, uid):
148
- model_labels = list(model_id2label.values())
149
  all_mappings = read_column_mapping(uid)
150
  # For flattened raw datasets with no labels
151
  # check if there are shared labels between model and dataset
@@ -163,7 +144,7 @@ def list_labels_and_features_from_dataset(ds_labels, ds_features, model_id2label
163
  gr.Dropdown(
164
  label=f"{label}",
165
  choices=model_labels,
166
- value=model_id2label[i % len(model_labels)],
167
  interactive=True,
168
  visible=True,
169
  )
@@ -195,25 +176,37 @@ def list_labels_and_features_from_dataset(ds_labels, ds_features, model_id2label
195
  def precheck_model_ds_enable_example_btn(
196
  model_id, dataset_id, dataset_config, dataset_split
197
  ):
198
- ppl = check_model(model_id)
199
- if ppl is None or not isinstance(ppl, TextClassificationPipeline):
200
  gr.Warning("Please check your model.")
201
  return gr.update(interactive=False), ""
202
- ds_labels, ds_features = get_labels_and_features_from_dataset(
203
- dataset_id, dataset_config, dataset_split
204
- )
205
- if not isinstance(ds_labels, list) or not isinstance(ds_features, list):
206
- gr.Warning(CHECK_CONFIG_OR_SPLIT_RAW)
207
- return gr.update(interactive=False), ""
208
 
209
- return gr.update(interactive=True), ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
 
211
 
212
  def align_columns_and_show_prediction(
213
  model_id, dataset_id, dataset_config, dataset_split, uid, run_inference, inference_token
214
  ):
215
- ppl = check_model(model_id)
216
- if ppl is None or not isinstance(ppl, TextClassificationPipeline):
217
  gr.Warning("Please check your model.")
218
  return (
219
  gr.update(visible=False),
@@ -228,20 +221,15 @@ def align_columns_and_show_prediction(
228
  gr.Dropdown(visible=False) for _ in range(MAX_LABELS + MAX_FEATURES)
229
  ]
230
 
231
- if ppl is None: # pipeline not found
232
- gr.Warning("Model not found")
233
- return (
234
- gr.update(visible=False),
235
- gr.update(visible=False),
236
- gr.update(visible=False, open=False),
237
- gr.update(interactive=False),
238
- *dropdown_placement,
239
- )
240
- model_id2label = ppl.model.config.id2label
241
- ds_labels, ds_features = get_labels_and_features_from_dataset(
242
- dataset_id, dataset_config, dataset_split
243
  )
244
 
 
 
 
 
 
245
  # when dataset does not have labels or features
246
  if not isinstance(ds_labels, list) or not isinstance(ds_features, list):
247
  gr.Warning(CHECK_CONFIG_OR_SPLIT_RAW)
@@ -257,14 +245,14 @@ def align_columns_and_show_prediction(
257
  column_mappings = list_labels_and_features_from_dataset(
258
  ds_labels,
259
  ds_features,
260
- model_id2label,
261
  uid,
262
  )
263
 
264
  # when labels or features are not aligned
265
  # show manually column mapping
266
  if (
267
- collections.Counter(model_id2label.values()) != collections.Counter(ds_labels)
268
  or ds_features[0] != "text"
269
  ):
270
  return (
@@ -276,9 +264,6 @@ def align_columns_and_show_prediction(
276
  *column_mappings,
277
  )
278
 
279
- prediction_input, prediction_output = get_example_prediction(
280
- ppl, dataset_id, dataset_config, dataset_split
281
- )
282
  return (
283
  gr.update(value=get_styled_input(prediction_input), visible=True),
284
  gr.update(value=prediction_output, visible=True),
@@ -322,10 +307,10 @@ def try_submit(m_id, d_id, config, split, inference, inference_token, uid):
322
  if os.environ.get("SPACE_ID") == "giskardai/giskard-evaluator":
323
  leaderboard_dataset = LEADERBOARD
324
 
325
- inference_type = "hf_pipeline"
326
- if inference and inference_token:
327
  inference_type = "hf_inference_api"
328
 
 
329
  # TODO: Set column mapping for some dataset such as `amazon_polarity`
330
  command = [
331
  "giskard_scanner",
@@ -354,6 +339,7 @@ def try_submit(m_id, d_id, config, split, inference, inference_token, uid):
354
  "--inference_api_token",
355
  inference_token,
356
  ]
 
357
  # The token to publish post
358
  if os.environ.get(HF_WRITE_TOKEN):
359
  command.append("--hf_token")
 
9
  import datasets
10
  import gradio as gr
11
  import pandas as pd
 
12
 
13
  from io_utils import (
14
  get_yaml_path,
 
18
  write_log_to_user_file,
19
  )
20
  from text_classification import (
21
+ check_model_task,
22
  get_example_prediction,
23
  get_labels_and_features_from_dataset,
24
  )
 
42
  HF_GSK_HUB_UNLOCK_TOKEN = "GSK_HUB_UNLOCK_TOKEN"
43
 
44
  LEADERBOARD = "giskard-bot/evaluator-leaderboard"
45
+
46
+ global ds_dict, ds_config
47
+ ds_dict = None
48
+ ds_config = None
49
+
50
  def get_related_datasets_from_leaderboard(model_id):
51
  records = leaderboard.records
52
  model_records = records[records["model_id"] == model_id]
53
+ datasets_unique = list(model_records["dataset_id"].unique())
54
+
55
  if len(datasets_unique) == 0:
56
  all_unique_datasets = list(records["dataset_id"].unique())
 
57
  return gr.update(choices=all_unique_datasets, value="")
58
+
59
  return gr.update(choices=datasets_unique, value=datasets_unique[0])
60
 
61
 
62
  logger = logging.getLogger(__file__)
63
 
64
 
65
+ def check_dataset(dataset_id):
66
+ logger.info(f"Loading {dataset_id}")
 
 
67
  try:
68
  configs = datasets.get_dataset_config_names(dataset_id)
69
+ if len(configs) == 0:
70
+ return (
71
+ gr.update(),
72
+ gr.update(),
73
+ ""
74
+ )
75
  splits = list(
76
+ datasets.load_dataset(
77
+ dataset_id, configs[0]
78
+ ).keys()
79
+ )
80
+ return (
81
+ gr.update(choices=configs, value=configs[0], visible=True),
82
+ gr.update(choices=splits, value=splits[0], visible=True),
83
+ ""
84
  )
 
 
 
 
 
85
  except Exception as e:
86
+ logger.warn(f"Check your dataset {dataset_id}: {e}")
87
+ return (
88
+ gr.update(),
89
+ gr.update(),
90
+ ""
91
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
 
 
 
 
 
 
 
94
 
95
  def write_column_mapping_to_config(uid, *labels):
96
  # TODO: Substitute 'text' with more features for zero-shot
 
126
  return all_mappings
127
 
128
 
129
+ def list_labels_and_features_from_dataset(ds_labels, ds_features, model_labels, uid):
 
130
  all_mappings = read_column_mapping(uid)
131
  # For flattened raw datasets with no labels
132
  # check if there are shared labels between model and dataset
 
144
  gr.Dropdown(
145
  label=f"{label}",
146
  choices=model_labels,
147
+ value=model_labels[i % len(model_labels)],
148
  interactive=True,
149
  visible=True,
150
  )
 
176
  def precheck_model_ds_enable_example_btn(
177
  model_id, dataset_id, dataset_config, dataset_split
178
  ):
179
+ model_task = check_model_task(model_id)
180
+ if model_task is None or model_task != "text-classification":
181
  gr.Warning("Please check your model.")
182
  return gr.update(interactive=False), ""
 
 
 
 
 
 
183
 
184
+ if dataset_config is None or dataset_split is None or len(dataset_config) == 0:
185
+ return (gr.update(), gr.update(), "")
186
+
187
+ try:
188
+ ds = datasets.load_dataset(dataset_id, dataset_config)
189
+ df: pd.DataFrame = ds[dataset_split].to_pandas().head(5)
190
+ ds_labels, ds_features = get_labels_and_features_from_dataset(ds[dataset_split])
191
+
192
+ if not isinstance(ds_labels, list) or not isinstance(ds_features, list):
193
+ gr.Warning(CHECK_CONFIG_OR_SPLIT_RAW)
194
+ return (gr.update(interactive=False), gr.update(value=df, visible=True), "")
195
+
196
+ return (gr.update(interactive=True), gr.update(value=df, visible=True), "")
197
+ except Exception as e:
198
+ # Config or split wrong
199
+ gr.Warning(f"Failed to load dataset {dataset_id} with config {dataset_config}: {e}")
200
+ return (gr.update(interactive=False), gr.update(value=pd.DataFrame(), visible=False), "")
201
+
202
+
203
 
204
 
205
  def align_columns_and_show_prediction(
206
  model_id, dataset_id, dataset_config, dataset_split, uid, run_inference, inference_token
207
  ):
208
+ model_task = check_model_task(model_id)
209
+ if model_task is None or model_task != "text-classification":
210
  gr.Warning("Please check your model.")
211
  return (
212
  gr.update(visible=False),
 
221
  gr.Dropdown(visible=False) for _ in range(MAX_LABELS + MAX_FEATURES)
222
  ]
223
 
224
+ prediction_input, prediction_output = get_example_prediction(
225
+ model_id, dataset_id, dataset_config, dataset_split
 
 
 
 
 
 
 
 
 
 
226
  )
227
 
228
+ model_labels = list(prediction_output.keys())
229
+
230
+ ds = datasets.load_dataset(dataset_id, dataset_config)[dataset_split]
231
+ ds_labels, ds_features = get_labels_and_features_from_dataset(ds)
232
+
233
  # when dataset does not have labels or features
234
  if not isinstance(ds_labels, list) or not isinstance(ds_features, list):
235
  gr.Warning(CHECK_CONFIG_OR_SPLIT_RAW)
 
245
  column_mappings = list_labels_and_features_from_dataset(
246
  ds_labels,
247
  ds_features,
248
+ model_labels,
249
  uid,
250
  )
251
 
252
  # when labels or features are not aligned
253
  # show manually column mapping
254
  if (
255
+ collections.Counter(model_labels) != collections.Counter(ds_labels)
256
  or ds_features[0] != "text"
257
  ):
258
  return (
 
264
  *column_mappings,
265
  )
266
 
 
 
 
267
  return (
268
  gr.update(value=get_styled_input(prediction_input), visible=True),
269
  gr.update(value=prediction_output, visible=True),
 
307
  if os.environ.get("SPACE_ID") == "giskardai/giskard-evaluator":
308
  leaderboard_dataset = LEADERBOARD
309
 
310
+ if inference:
 
311
  inference_type = "hf_inference_api"
312
 
313
+
314
  # TODO: Set column mapping for some dataset such as `amazon_polarity`
315
  command = [
316
  "giskard_scanner",
 
339
  "--inference_api_token",
340
  inference_token,
341
  ]
342
+
343
  # The token to publish post
344
  if os.environ.get(HF_WRITE_TOKEN):
345
  command.append("--hf_token")
wordings.py CHANGED
@@ -38,7 +38,26 @@ MAPPING_STYLED_ERROR_WARNING = """
38
  </h3>
39
  """
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  def get_styled_input(input):
42
- return f"""<h3 style="text-align: center;color: #5ec26a; background-color: #e2fbe8; border-radius: 8px; padding: 10px; ">
43
  Sample input: {input}
44
  </h3>"""
 
38
  </h3>
39
  """
40
 
41
+ USE_INFERENCE_API_TIP = """
42
+ We recommend to use
43
+ <a href="https://huggingface.co/docs/api-inference/detailed_parameters#text-classification-task">
44
+ Hugging Face Inference API
45
+ </a>
46
+ for the evaluation,
47
+ which requires your <a href="https://huggingface.co/settings/tokens">HF token</a>.
48
+ <br/>
49
+ Otherwise, an
50
+ <a href="https://huggingface.co/docs/transformers/main_classes/pipelines#transformers.TextClassificationPipeline">
51
+ HF pipeline
52
+ </a>
53
+ will be created and run in this Space. It takes more time to get the result.
54
+ <br/>
55
+ <b>
56
+ Do not worry, your HF token is only used in this Space for your evaluation.
57
+ </b>
58
+ """
59
+
60
  def get_styled_input(input):
61
+ return f"""<h3 style="text-align: center;color: #4ca154; background-color: #e2fbe8; border-radius: 8px; padding: 10px; ">
62
  Sample input: {input}
63
  </h3>"""