GSK-2509 fix not standard label columns (go_emotions)

#29
by ZeroCommand - opened
app.py CHANGED
@@ -10,7 +10,7 @@ from run_jobs import start_process_run_job, stop_thread
10
  try:
11
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="green")) as demo:
12
  with gr.Tab("Text Classification"):
13
- get_demo_text_classification(demo)
14
  with gr.Tab("Leaderboard"):
15
  get_demo_leaderboard()
16
  with gr.Tab("Logs(Debug)"):
 
10
  try:
11
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="green")) as demo:
12
  with gr.Tab("Text Classification"):
13
+ get_demo_text_classification()
14
  with gr.Tab("Leaderboard"):
15
  get_demo_leaderboard()
16
  with gr.Tab("Logs(Debug)"):
app_text_classification.py CHANGED
@@ -2,17 +2,17 @@ import uuid
2
 
3
  import gradio as gr
4
 
5
- from io_utils import (get_logs_file, read_inference_type, read_scanners,
6
- write_inference_type, write_scanners)
7
  from text_classification_ui_helpers import (check_dataset_and_get_config,
8
  check_dataset_and_get_split,
9
- check_model_and_show_prediction,
10
  deselect_run_inference,
11
  select_run_mode, try_submit,
12
- write_column_mapping_to_config)
 
13
  from wordings import CONFIRM_MAPPING_DETAILS_MD, INTRODUCTION_MD
14
 
15
- MAX_LABELS = 20
16
  MAX_FEATURES = 20
17
 
18
  EXAMPLE_MODEL_ID = "cardiffnlp/twitter-roberta-base-sentiment-latest"
@@ -20,7 +20,7 @@ EXAMPLE_DATA_ID = "tweet_eval"
20
  CONFIG_PATH = "./config.yaml"
21
 
22
 
23
- def get_demo(demo):
24
  with gr.Row():
25
  gr.Markdown(INTRODUCTION_MD)
26
  uid_label = gr.Textbox(
@@ -41,6 +41,13 @@ def get_demo(demo):
41
  dataset_config_input = gr.Dropdown(label="Dataset Config", visible=False)
42
  dataset_split_input = gr.Dropdown(label="Dataset Split", visible=False)
43
 
 
 
 
 
 
 
 
44
  with gr.Row():
45
  example_input = gr.HTML(visible=False)
46
  with gr.Row():
@@ -55,23 +62,17 @@ def get_demo(demo):
55
  column_mappings = []
56
  with gr.Row():
57
  with gr.Column():
 
58
  for _ in range(MAX_LABELS):
59
  column_mappings.append(gr.Dropdown(visible=False))
60
  with gr.Column():
 
61
  for _ in range(MAX_LABELS, MAX_LABELS + MAX_FEATURES):
62
  column_mappings.append(gr.Dropdown(visible=False))
63
 
64
  with gr.Accordion(label="Model Wrap Advance Config (optional)", open=False):
65
  run_local = gr.Checkbox(value=True, label="Run in this Space")
66
- run_inference = gr.Checkbox(value="False", label="Run with Inference API")
67
-
68
- @gr.on(triggers=[uid_label.change], inputs=[uid_label], outputs=[run_inference])
69
- def get_run_mode(uid):
70
- return gr.update(
71
- value=read_inference_type(uid) == "hf_inference_api"
72
- and not run_local.value
73
- )
74
-
75
  inference_token = gr.Textbox(
76
  value="",
77
  label="HF Token for Inference API",
@@ -97,13 +98,12 @@ def get_demo(demo):
97
  run_btn = gr.Button(
98
  "Get Evaluation Result",
99
  variant="primary",
100
- interactive=True,
101
  size="lg",
102
  )
103
 
104
  with gr.Row():
105
- logs = gr.Textbox(label="Giskard Bot Evaluation Log:", visible=False)
106
- demo.load(get_logs_file, None, logs, every=0.5)
107
 
108
  dataset_id_input.change(
109
  check_dataset_and_get_config,
@@ -121,7 +121,7 @@ def get_demo(demo):
121
 
122
  run_inference.change(
123
  select_run_mode,
124
- inputs=[run_inference, inference_token, uid_label],
125
  outputs=[inference_token, run_local],
126
  )
127
 
@@ -131,17 +131,10 @@ def get_demo(demo):
131
  outputs=[inference_token, run_inference],
132
  )
133
 
134
- inference_token.change(
135
- write_inference_type, inputs=[run_inference, inference_token, uid_label]
136
- )
137
-
138
  gr.on(
139
  triggers=[label.change for label in column_mappings],
140
  fn=write_column_mapping_to_config,
141
  inputs=[
142
- dataset_id_input,
143
- dataset_config_input,
144
- dataset_split_input,
145
  uid_label,
146
  *column_mappings,
147
  ],
@@ -152,9 +145,6 @@ def get_demo(demo):
152
  triggers=[label.input for label in column_mappings],
153
  fn=write_column_mapping_to_config,
154
  inputs=[
155
- dataset_id_input,
156
- dataset_config_input,
157
- dataset_split_input,
158
  uid_label,
159
  *column_mappings,
160
  ],
@@ -165,19 +155,33 @@ def get_demo(demo):
165
  model_id_input.change,
166
  dataset_id_input.change,
167
  dataset_config_input.change,
168
- dataset_split_input.change,
 
 
 
 
 
 
 
 
 
 
 
 
169
  ],
170
- fn=check_model_and_show_prediction,
171
  inputs=[
172
  model_id_input,
173
  dataset_id_input,
174
  dataset_config_input,
175
  dataset_split_input,
 
176
  ],
177
  outputs=[
178
  example_input,
179
  example_prediction,
180
  column_mapping_accordion,
 
181
  *column_mappings,
182
  ],
183
  )
@@ -193,6 +197,8 @@ def get_demo(demo):
193
  dataset_config_input,
194
  dataset_split_input,
195
  run_local,
 
 
196
  uid_label,
197
  ],
198
  outputs=[run_btn, logs],
@@ -203,12 +209,10 @@ def get_demo(demo):
203
 
204
  gr.on(
205
  triggers=[
206
- model_id_input.change,
207
- dataset_config_input.change,
208
- dataset_split_input.change,
209
- run_inference.change,
210
- run_local.change,
211
- scanners.change,
212
  ],
213
  fn=enable_run_btn,
214
  inputs=None,
@@ -216,8 +220,8 @@ def get_demo(demo):
216
  )
217
 
218
  gr.on(
219
- triggers=[label.change for label in column_mappings],
220
  fn=enable_run_btn,
221
- inputs=None,
222
  outputs=[run_btn],
223
  )
 
2
 
3
  import gradio as gr
4
 
5
+ from io_utils import (get_logs_file, read_scanners, write_scanners)
 
6
  from text_classification_ui_helpers import (check_dataset_and_get_config,
7
  check_dataset_and_get_split,
8
+ align_columns_and_show_prediction,
9
  deselect_run_inference,
10
  select_run_mode, try_submit,
11
+ write_column_mapping_to_config,
12
+ precheck_model_ds_enable_example_btn)
13
  from wordings import CONFIRM_MAPPING_DETAILS_MD, INTRODUCTION_MD
14
 
15
+ MAX_LABELS = 40
16
  MAX_FEATURES = 20
17
 
18
  EXAMPLE_MODEL_ID = "cardiffnlp/twitter-roberta-base-sentiment-latest"
 
20
  CONFIG_PATH = "./config.yaml"
21
 
22
 
23
+ def get_demo():
24
  with gr.Row():
25
  gr.Markdown(INTRODUCTION_MD)
26
  uid_label = gr.Textbox(
 
41
  dataset_config_input = gr.Dropdown(label="Dataset Config", visible=False)
42
  dataset_split_input = gr.Dropdown(label="Dataset Split", visible=False)
43
 
44
+ with gr.Row():
45
+ example_btn = gr.Button(
46
+ "Auto-align Columns & Get Sample Prediction",
47
+ visible=True,
48
+ variant="primary",
49
+ interactive=False)
50
+
51
  with gr.Row():
52
  example_input = gr.HTML(visible=False)
53
  with gr.Row():
 
62
  column_mappings = []
63
  with gr.Row():
64
  with gr.Column():
65
+ gr.Markdown("# Label Mapping")
66
  for _ in range(MAX_LABELS):
67
  column_mappings.append(gr.Dropdown(visible=False))
68
  with gr.Column():
69
+ gr.Markdown("# Feature Mapping")
70
  for _ in range(MAX_LABELS, MAX_LABELS + MAX_FEATURES):
71
  column_mappings.append(gr.Dropdown(visible=False))
72
 
73
  with gr.Accordion(label="Model Wrap Advance Config (optional)", open=False):
74
  run_local = gr.Checkbox(value=True, label="Run in this Space")
75
+ run_inference = gr.Checkbox(value=False, label="Run with Inference API")
 
 
 
 
 
 
 
 
76
  inference_token = gr.Textbox(
77
  value="",
78
  label="HF Token for Inference API",
 
98
  run_btn = gr.Button(
99
  "Get Evaluation Result",
100
  variant="primary",
101
+ interactive=False,
102
  size="lg",
103
  )
104
 
105
  with gr.Row():
106
+ logs = gr.Textbox(value=get_logs_file, label="Giskard Bot Evaluation Log:", visible=False, every=0.5)
 
107
 
108
  dataset_id_input.change(
109
  check_dataset_and_get_config,
 
121
 
122
  run_inference.change(
123
  select_run_mode,
124
+ inputs=[run_inference],
125
  outputs=[inference_token, run_local],
126
  )
127
 
 
131
  outputs=[inference_token, run_inference],
132
  )
133
 
 
 
 
 
134
  gr.on(
135
  triggers=[label.change for label in column_mappings],
136
  fn=write_column_mapping_to_config,
137
  inputs=[
 
 
 
138
  uid_label,
139
  *column_mappings,
140
  ],
 
145
  triggers=[label.input for label in column_mappings],
146
  fn=write_column_mapping_to_config,
147
  inputs=[
 
 
 
148
  uid_label,
149
  *column_mappings,
150
  ],
 
155
  model_id_input.change,
156
  dataset_id_input.change,
157
  dataset_config_input.change,
158
+ dataset_split_input.change],
159
+ fn=precheck_model_ds_enable_example_btn,
160
+ inputs=[
161
+ model_id_input,
162
+ dataset_id_input,
163
+ dataset_config_input,
164
+ dataset_split_input,
165
+ ],
166
+ outputs=[example_btn])
167
+
168
+ gr.on(
169
+ triggers=[
170
+ example_btn.click,
171
  ],
172
+ fn=align_columns_and_show_prediction,
173
  inputs=[
174
  model_id_input,
175
  dataset_id_input,
176
  dataset_config_input,
177
  dataset_split_input,
178
+ uid_label,
179
  ],
180
  outputs=[
181
  example_input,
182
  example_prediction,
183
  column_mapping_accordion,
184
+ run_btn,
185
  *column_mappings,
186
  ],
187
  )
 
197
  dataset_config_input,
198
  dataset_split_input,
199
  run_local,
200
+ run_inference,
201
+ inference_token,
202
  uid_label,
203
  ],
204
  outputs=[run_btn, logs],
 
209
 
210
  gr.on(
211
  triggers=[
212
+ run_inference.input,
213
+ run_local.input,
214
+ inference_token.input,
215
+ scanners.input,
 
 
216
  ],
217
  fn=enable_run_btn,
218
  inputs=None,
 
220
  )
221
 
222
  gr.on(
223
+ triggers=[label.input for label in column_mappings],
224
  fn=enable_run_btn,
225
+ inputs=column_mappings,
226
  outputs=[run_btn],
227
  )
io_utils.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  import subprocess
3
 
4
  import yaml
@@ -6,6 +7,7 @@ import yaml
6
  import pipe
7
 
8
  YAML_PATH = "./cicd/configs"
 
9
 
10
 
11
  class Dumper(yaml.Dumper):
@@ -28,7 +30,6 @@ def read_scanners(uid):
28
  with open(get_yaml_path(uid), "r") as f:
29
  config = yaml.load(f, Loader=yaml.FullLoader)
30
  scanners = config.get("detectors", [])
31
- f.close()
32
  return scanners
33
 
34
 
@@ -38,11 +39,9 @@ def write_scanners(scanners, uid):
38
  config = yaml.load(f, Loader=yaml.FullLoader)
39
  if config:
40
  config["detectors"] = scanners
41
- f.close()
42
  # save scanners to detectors in yaml
43
  with open(get_yaml_path(uid), "w") as f:
44
  yaml.dump(config, f, Dumper=Dumper)
45
- f.close()
46
 
47
 
48
  # read model_type from yaml file
@@ -51,7 +50,6 @@ def read_inference_type(uid):
51
  with open(get_yaml_path(uid), "r") as f:
52
  config = yaml.load(f, Loader=yaml.FullLoader)
53
  inference_type = config.get("inference_type", "")
54
- f.close()
55
  return inference_type
56
 
57
 
@@ -66,11 +64,9 @@ def write_inference_type(use_inference, inference_token, uid):
66
  config["inference_type"] = "hf_pipeline"
67
  # FIXME: A quick and temp fix for missing token
68
  config["inference_token"] = ""
69
- f.close()
70
  # save inference_type to inference_type in yaml
71
  with open(get_yaml_path(uid), "w") as f:
72
  yaml.dump(config, f, Dumper=Dumper)
73
- f.close()
74
 
75
 
76
  # read column mapping from yaml file
@@ -80,7 +76,6 @@ def read_column_mapping(uid):
80
  config = yaml.load(f, Loader=yaml.FullLoader)
81
  if config:
82
  column_mapping = config.get("column_mapping", dict())
83
- f.close()
84
  return column_mapping
85
 
86
 
@@ -88,7 +83,6 @@ def read_column_mapping(uid):
88
  def write_column_mapping(mapping, uid):
89
  with open(get_yaml_path(uid), "r") as f:
90
  config = yaml.load(f, Loader=yaml.FullLoader)
91
- f.close()
92
 
93
  if config is None:
94
  return
@@ -96,10 +90,9 @@ def write_column_mapping(mapping, uid):
96
  del config["column_mapping"]
97
  else:
98
  config["column_mapping"] = mapping
99
-
100
  with open(get_yaml_path(uid), "w") as f:
101
- yaml.dump(config, f, Dumper=Dumper)
102
- f.close()
103
 
104
 
105
  # convert column mapping dataframe to json
@@ -113,21 +106,20 @@ def convert_column_mapping_to_json(df, label=""):
113
 
114
  def get_logs_file():
115
  try:
116
- file = open(f"./tmp/temp_log", "r")
117
- return file.read()
118
  except Exception:
119
  return "Log file does not exist"
120
 
121
 
122
- def write_log_to_user_file(id, log):
123
- with open(f"./tmp/temp_log", "a") as f:
124
  f.write(log)
125
- f.close()
126
 
127
 
128
- def save_job_to_pipe(id, job, description, lock):
129
  with lock:
130
- pipe.jobs.append((id, job, description))
131
 
132
 
133
  def pop_job_from_pipe():
@@ -135,14 +127,17 @@ def pop_job_from_pipe():
135
  return
136
  job_info = pipe.jobs.pop()
137
  pipe.current = job_info[2]
138
- write_log_to_user_file(job_info[0], f"Running job id {job_info[0]}\n")
 
139
  command = job_info[1]
140
 
141
- log_file = open(f"./tmp/temp_log", "a")
142
- p = subprocess.Popen(
143
- command,
144
- stdout=log_file,
145
- stderr=log_file,
146
- )
147
- p.wait()
 
 
148
  pipe.current = None
 
1
  import os
2
+ from pathlib import Path
3
  import subprocess
4
 
5
  import yaml
 
7
  import pipe
8
 
9
  YAML_PATH = "./cicd/configs"
10
+ LOG_FILE = "temp_log"
11
 
12
 
13
  class Dumper(yaml.Dumper):
 
30
  with open(get_yaml_path(uid), "r") as f:
31
  config = yaml.load(f, Loader=yaml.FullLoader)
32
  scanners = config.get("detectors", [])
 
33
  return scanners
34
 
35
 
 
39
  config = yaml.load(f, Loader=yaml.FullLoader)
40
  if config:
41
  config["detectors"] = scanners
 
42
  # save scanners to detectors in yaml
43
  with open(get_yaml_path(uid), "w") as f:
44
  yaml.dump(config, f, Dumper=Dumper)
 
45
 
46
 
47
  # read model_type from yaml file
 
50
  with open(get_yaml_path(uid), "r") as f:
51
  config = yaml.load(f, Loader=yaml.FullLoader)
52
  inference_type = config.get("inference_type", "")
 
53
  return inference_type
54
 
55
 
 
64
  config["inference_type"] = "hf_pipeline"
65
  # FIXME: A quick and temp fix for missing token
66
  config["inference_token"] = ""
 
67
  # save inference_type to inference_type in yaml
68
  with open(get_yaml_path(uid), "w") as f:
69
  yaml.dump(config, f, Dumper=Dumper)
 
70
 
71
 
72
  # read column mapping from yaml file
 
76
  config = yaml.load(f, Loader=yaml.FullLoader)
77
  if config:
78
  column_mapping = config.get("column_mapping", dict())
 
79
  return column_mapping
80
 
81
 
 
83
  def write_column_mapping(mapping, uid):
84
  with open(get_yaml_path(uid), "r") as f:
85
  config = yaml.load(f, Loader=yaml.FullLoader)
 
86
 
87
  if config is None:
88
  return
 
90
  del config["column_mapping"]
91
  else:
92
  config["column_mapping"] = mapping
 
93
  with open(get_yaml_path(uid), "w") as f:
94
+ # yaml Dumper will by default sort the keys
95
+ yaml.dump(config, f, Dumper=Dumper, sort_keys=False)
96
 
97
 
98
  # convert column mapping dataframe to json
 
106
 
107
  def get_logs_file():
108
  try:
109
+ with open(LOG_FILE, "r") as file:
110
+ return file.read()
111
  except Exception:
112
  return "Log file does not exist"
113
 
114
 
115
+ def write_log_to_user_file(task_id, log):
116
+ with open(f"./tmp/{task_id}.log", "a") as f:
117
  f.write(log)
 
118
 
119
 
120
+ def save_job_to_pipe(task_id, job, description, lock):
121
  with lock:
122
+ pipe.jobs.append((task_id, job, description))
123
 
124
 
125
  def pop_job_from_pipe():
 
127
  return
128
  job_info = pipe.jobs.pop()
129
  pipe.current = job_info[2]
130
+ task_id = job_info[0]
131
+ write_log_to_user_file(task_id, f"Running job id {task_id}\n")
132
  command = job_info[1]
133
 
134
+ # Link to LOG_FILE
135
+ log_file_path = Path(LOG_FILE)
136
+ if log_file_path.exists():
137
+ log_file_path.unlink()
138
+ os.symlink(f"./tmp/{task_id}.log", LOG_FILE)
139
+
140
+ with open(f"./tmp/{task_id}.log", "a") as log_file:
141
+ p = subprocess.Popen(command, stdout=log_file, stderr=subprocess.STDOUT)
142
+ p.wait()
143
  pipe.current = None
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
- giskard >= 2.1.0, < 2.3.0
2
  huggingface_hub
3
  torch==2.0.1
4
  transformers
 
1
+ giskard==2.1.2
2
  huggingface_hub
3
  torch==2.0.1
4
  transformers
text_classification.py CHANGED
@@ -15,8 +15,17 @@ def get_labels_and_features_from_dataset(dataset_id, dataset_config, split):
15
  try:
16
  ds = datasets.load_dataset(dataset_id, dataset_config)[split]
17
  dataset_features = ds.features
18
- labels = dataset_features["label"].names
19
- features = [f for f in dataset_features.keys() if f != "label"]
 
 
 
 
 
 
 
 
 
20
  return labels, features
21
  except Exception as e:
22
  logging.warning(
 
15
  try:
16
  ds = datasets.load_dataset(dataset_id, dataset_config)[split]
17
  dataset_features = ds.features
18
+ label_keys = [i for i in dataset_features.keys() if i.startswith('label')]
19
+ if len(label_keys) == 0: # no labels found
20
+ # return everything for post processing
21
+ return list(dataset_features.keys()), list(dataset_features.keys())
22
+ if not isinstance(dataset_features[label_keys[0]], datasets.ClassLabel):
23
+ if hasattr(dataset_features[label_keys[0]], 'feature'):
24
+ label_feat = dataset_features[label_keys[0]].feature
25
+ labels = label_feat.names
26
+ else:
27
+ labels = [dataset_features[label_keys[0]].names]
28
+ features = [f for f in dataset_features.keys() if not f.startswith("label")]
29
  return labels, features
30
  except Exception as e:
31
  logging.warning(
text_classification_ui_helpers.py CHANGED
@@ -10,7 +10,7 @@ from transformers.pipelines import TextClassificationPipeline
10
  from wordings import get_styled_input
11
 
12
  from io_utils import (get_yaml_path, read_column_mapping, save_job_to_pipe,
13
- write_column_mapping, write_inference_type,
14
  write_log_to_user_file)
15
  from text_classification import (check_model, get_example_prediction,
16
  get_labels_and_features_from_dataset)
@@ -18,7 +18,7 @@ from wordings import (CHECK_CONFIG_OR_SPLIT_RAW,
18
  CONFIRM_MAPPING_DETAILS_FAIL_RAW,
19
  MAPPING_STYLED_ERROR_WARNING)
20
 
21
- MAX_LABELS = 20
22
  MAX_FEATURES = 20
23
 
24
  HF_REPO_ID = "HF_REPO_ID"
@@ -51,10 +51,8 @@ def check_dataset_and_get_split(dataset_id, dataset_config):
51
  pass
52
 
53
 
54
- def select_run_mode(run_inf, inf_token, uid):
55
  if run_inf:
56
- if len(inf_token) > 0:
57
- write_inference_type(run_inf, inf_token, uid)
58
  return (gr.update(visible=True), gr.update(value=False))
59
  else:
60
  return (gr.update(visible=False), gr.update(value=True))
@@ -68,46 +66,62 @@ def deselect_run_inference(run_local):
68
 
69
 
70
  def write_column_mapping_to_config(
71
- dataset_id, dataset_config, dataset_split, uid, *labels
72
  ):
73
  # TODO: Substitute 'text' with more features for zero-shot
74
  # we are not using ds features because we only support "text" for now
75
- ds_labels, _ = get_labels_and_features_from_dataset(
76
- dataset_id, dataset_config, dataset_split
77
- )
78
  if labels is None:
79
  return
 
 
80
 
81
- all_mappings = dict()
82
-
83
- if "labels" not in all_mappings.keys():
84
- all_mappings["labels"] = dict()
85
- for i, label in enumerate(labels[:MAX_LABELS]):
86
- if label:
87
- all_mappings["labels"][label] = ds_labels[i % len(ds_labels)]
88
- if "features" not in all_mappings.keys():
89
- all_mappings["features"] = dict()
90
- for _, feat in enumerate(labels[MAX_LABELS : (MAX_LABELS + MAX_FEATURES)]):
91
- if feat:
92
- # TODO: Substitute 'text' with more features for zero-shot
93
- all_mappings["features"]["text"] = feat
94
  write_column_mapping(all_mappings, uid)
95
 
96
-
97
- def list_labels_and_features_from_dataset(ds_labels, ds_features, model_id2label):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  model_labels = list(model_id2label.values())
99
- len_model_labels = len(model_labels)
 
 
 
 
 
 
 
 
 
 
 
 
100
  lables = [
101
  gr.Dropdown(
102
  label=f"{label}",
103
  choices=model_labels,
104
- value=model_id2label[i % len_model_labels],
105
  interactive=True,
106
  visible=True,
107
  )
108
- for i, label in enumerate(ds_labels[:MAX_LABELS])
109
  ]
110
  lables += [gr.Dropdown(visible=False) for _ in range(MAX_LABELS - len(lables))]
 
 
111
  # TODO: Substitute 'text' with more features for zero-shot
112
  features = [
113
  gr.Dropdown(
@@ -122,11 +136,27 @@ def list_labels_and_features_from_dataset(ds_labels, ds_features, model_id2label
122
  features += [
123
  gr.Dropdown(visible=False) for _ in range(MAX_FEATURES - len(features))
124
  ]
 
 
 
125
  return lables + features
126
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
- def check_model_and_show_prediction(
129
- model_id, dataset_id, dataset_config, dataset_split
130
  ):
131
  ppl = check_model(model_id)
132
  if ppl is None or not isinstance(ppl, TextClassificationPipeline):
@@ -134,6 +164,8 @@ def check_model_and_show_prediction(
134
  return (
135
  gr.update(visible=False),
136
  gr.update(visible=False),
 
 
137
  *[gr.update(visible=False) for _ in range(MAX_LABELS + MAX_FEATURES)],
138
  )
139
 
@@ -147,6 +179,7 @@ def check_model_and_show_prediction(
147
  gr.update(visible=False),
148
  gr.update(visible=False),
149
  gr.update(visible=False, open=False),
 
150
  *dropdown_placement,
151
  )
152
  model_id2label = ppl.model.config.id2label
@@ -161,6 +194,7 @@ def check_model_and_show_prediction(
161
  gr.update(visible=False),
162
  gr.update(visible=False),
163
  gr.update(visible=False, open=False),
 
164
  *dropdown_placement,
165
  )
166
 
@@ -168,6 +202,7 @@ def check_model_and_show_prediction(
168
  ds_labels,
169
  ds_features,
170
  model_id2label,
 
171
  )
172
 
173
  # when labels or features are not aligned
@@ -180,6 +215,7 @@ def check_model_and_show_prediction(
180
  gr.update(value=MAPPING_STYLED_ERROR_WARNING, visible=True),
181
  gr.update(visible=False),
182
  gr.update(visible=True, open=True),
 
183
  *column_mappings,
184
  )
185
 
@@ -190,13 +226,11 @@ def check_model_and_show_prediction(
190
  gr.update(value=get_styled_input(prediction_input), visible=True),
191
  gr.update(value=prediction_output, visible=True),
192
  gr.update(visible=True, open=False),
 
193
  *column_mappings,
194
  )
195
 
196
-
197
- def try_submit(m_id, d_id, config, split, local, uid):
198
- all_mappings = read_column_mapping(uid)
199
-
200
  if all_mappings is None:
201
  gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
202
  return (gr.update(interactive=True), gr.update(visible=False))
@@ -204,6 +238,8 @@ def try_submit(m_id, d_id, config, split, local, uid):
204
  if "labels" not in all_mappings.keys():
205
  gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
206
  return (gr.update(interactive=True), gr.update(visible=False))
 
 
207
  label_mapping = {}
208
  for i, label in zip(
209
  range(len(all_mappings["labels"].keys())), all_mappings["labels"].keys()
@@ -214,73 +250,88 @@ def try_submit(m_id, d_id, config, split, local, uid):
214
  gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
215
  return (gr.update(interactive=True), gr.update(visible=False))
216
  feature_mapping = all_mappings["features"]
 
 
 
 
 
 
217
 
218
  leaderboard_dataset = None
219
  if os.environ.get("SPACE_ID") == "giskardai/giskard-evaluator":
220
  leaderboard_dataset = "ZeroCommand/test-giskard-report"
 
 
 
 
 
221
 
222
  # TODO: Set column mapping for some dataset such as `amazon_polarity`
223
- if local:
224
- command = [
225
- "giskard_scanner",
226
- "--loader",
227
- "huggingface",
228
- "--model",
229
- m_id,
230
- "--dataset",
231
- d_id,
232
- "--dataset_config",
233
- config,
234
- "--dataset_split",
235
- split,
236
- "--hf_token",
237
- os.environ.get(HF_WRITE_TOKEN),
238
- "--discussion_repo",
239
- os.environ.get(HF_REPO_ID) or os.environ.get(HF_SPACE_ID),
240
- "--output_format",
241
- "markdown",
242
- "--output_portal",
243
- "huggingface",
244
- "--feature_mapping",
245
- json.dumps(feature_mapping),
246
- "--label_mapping",
247
- json.dumps(label_mapping),
248
- "--scan_config",
249
- get_yaml_path(uid),
250
- "--leaderboard_dataset",
251
- leaderboard_dataset,
252
- ]
253
- if os.environ.get(HF_GSK_HUB_KEY):
254
- command.append("--giskard_hub_api_key")
255
- command.append(os.environ.get(HF_GSK_HUB_KEY))
256
- if os.environ.get(HF_GSK_HUB_URL):
257
- command.append("--giskard_hub_url")
258
- command.append(os.environ.get(HF_GSK_HUB_URL))
259
- if os.environ.get(HF_GSK_HUB_PROJECT_KEY):
260
- command.append("--giskard_hub_project_key")
261
- command.append(os.environ.get(HF_GSK_HUB_PROJECT_KEY))
262
- if os.environ.get(HF_GSK_HUB_HF_TOKEN):
263
- command.append("--giskard_hub_hf_token")
264
- command.append(os.environ.get(HF_GSK_HUB_HF_TOKEN))
265
- if os.environ.get(HF_GSK_HUB_UNLOCK_TOKEN):
266
- command.append("--giskard_hub_unlock_token")
267
- command.append(os.environ.get(HF_GSK_HUB_UNLOCK_TOKEN))
268
-
269
- eval_str = f"[{m_id}]<{d_id}({config}, {split} set)>"
270
- logging.info(f"Start local evaluation on {eval_str}")
271
- save_job_to_pipe(uid, command, eval_str, threading.Lock())
272
- write_log_to_user_file(
273
- uid,
274
- f"Start local evaluation on {eval_str}. Please wait for your job to start...\n",
275
- )
276
- gr.Info(f"Start local evaluation on {eval_str}")
 
 
 
 
277
 
278
- return (
279
- gr.update(interactive=False),
280
- gr.update(lines=5, visible=True, interactive=False),
281
- )
282
 
283
- else:
284
- gr.Info("TODO: Submit task to an endpoint")
285
 
286
- return (gr.update(interactive=True), gr.update(visible=False)) # Submit button
 
 
 
10
  from wordings import get_styled_input
11
 
12
  from io_utils import (get_yaml_path, read_column_mapping, save_job_to_pipe,
13
+ write_column_mapping,
14
  write_log_to_user_file)
15
  from text_classification import (check_model, get_example_prediction,
16
  get_labels_and_features_from_dataset)
 
18
  CONFIRM_MAPPING_DETAILS_FAIL_RAW,
19
  MAPPING_STYLED_ERROR_WARNING)
20
 
21
+ MAX_LABELS = 40
22
  MAX_FEATURES = 20
23
 
24
  HF_REPO_ID = "HF_REPO_ID"
 
51
  pass
52
 
53
 
54
+ def select_run_mode(run_inf):
55
  if run_inf:
 
 
56
  return (gr.update(visible=True), gr.update(value=False))
57
  else:
58
  return (gr.update(visible=False), gr.update(value=True))
 
66
 
67
 
68
  def write_column_mapping_to_config(
69
+ uid, *labels
70
  ):
71
  # TODO: Substitute 'text' with more features for zero-shot
72
  # we are not using ds features because we only support "text" for now
73
+ all_mappings = read_column_mapping(uid)
74
+
 
75
  if labels is None:
76
  return
77
+ all_mappings = export_mappings(all_mappings, "labels", None, labels[:MAX_LABELS])
78
+ all_mappings = export_mappings(all_mappings, "features", ["text"], labels[MAX_LABELS : (MAX_LABELS + MAX_FEATURES)])
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  write_column_mapping(all_mappings, uid)
81
 
82
+ def export_mappings(all_mappings, key, subkeys, values):
83
+ if key not in all_mappings.keys():
84
+ all_mappings[key] = dict()
85
+ if subkeys is None:
86
+ subkeys = list(all_mappings[key].keys())
87
+
88
+ if not subkeys:
89
+ logging.debug(f"subkeys is empty for {key}")
90
+ return all_mappings
91
+
92
+ for i, subkey in enumerate(subkeys):
93
+ if subkey:
94
+ all_mappings[key][subkey] = values[i % len(values)]
95
+ return all_mappings
96
+
97
+ def list_labels_and_features_from_dataset(ds_labels, ds_features, model_id2label, uid):
98
  model_labels = list(model_id2label.values())
99
+ all_mappings = read_column_mapping(uid)
100
+ # For flattened raw datasets with no labels
101
+ # check if there are shared labels between model and dataset
102
+ shared_labels = set(model_labels).intersection(set(ds_labels))
103
+ if shared_labels:
104
+ ds_labels = list(shared_labels)
105
+ if len(ds_labels) > MAX_LABELS:
106
+ ds_labels = ds_labels[:MAX_LABELS]
107
+ gr.Warning(f"The number of labels is truncated to length {MAX_LABELS}")
108
+
109
+ ds_labels.sort()
110
+ model_labels.sort()
111
+
112
  lables = [
113
  gr.Dropdown(
114
  label=f"{label}",
115
  choices=model_labels,
116
+ value=model_id2label[i % len(model_labels)],
117
  interactive=True,
118
  visible=True,
119
  )
120
+ for i, label in enumerate(ds_labels)
121
  ]
122
  lables += [gr.Dropdown(visible=False) for _ in range(MAX_LABELS - len(lables))]
123
+ all_mappings = export_mappings(all_mappings, "labels", ds_labels, model_labels)
124
+
125
  # TODO: Substitute 'text' with more features for zero-shot
126
  features = [
127
  gr.Dropdown(
 
136
  features += [
137
  gr.Dropdown(visible=False) for _ in range(MAX_FEATURES - len(features))
138
  ]
139
+ all_mappings = export_mappings(all_mappings, "features", ["text"], ds_features)
140
+ write_column_mapping(all_mappings, uid)
141
+
142
  return lables + features
143
 
144
+ def precheck_model_ds_enable_example_btn(model_id, dataset_id, dataset_config, dataset_split):
145
+ ppl = check_model(model_id)
146
+ if ppl is None or not isinstance(ppl, TextClassificationPipeline):
147
+ gr.Warning("Please check your model.")
148
+ return gr.update(interactive=False)
149
+ ds_labels, ds_features = get_labels_and_features_from_dataset(
150
+ dataset_id, dataset_config, dataset_split
151
+ )
152
+ if not isinstance(ds_labels, list) or not isinstance(ds_features, list):
153
+ gr.Warning(CHECK_CONFIG_OR_SPLIT_RAW)
154
+ return gr.update(interactive=False)
155
+
156
+ return gr.update(interactive=True)
157
 
158
+ def align_columns_and_show_prediction(
159
+ model_id, dataset_id, dataset_config, dataset_split, uid
160
  ):
161
  ppl = check_model(model_id)
162
  if ppl is None or not isinstance(ppl, TextClassificationPipeline):
 
164
  return (
165
  gr.update(visible=False),
166
  gr.update(visible=False),
167
+ gr.update(visible=False, open=False),
168
+ gr.update(interactive=False),
169
  *[gr.update(visible=False) for _ in range(MAX_LABELS + MAX_FEATURES)],
170
  )
171
 
 
179
  gr.update(visible=False),
180
  gr.update(visible=False),
181
  gr.update(visible=False, open=False),
182
+ gr.update(interactive=False),
183
  *dropdown_placement,
184
  )
185
  model_id2label = ppl.model.config.id2label
 
194
  gr.update(visible=False),
195
  gr.update(visible=False),
196
  gr.update(visible=False, open=False),
197
+ gr.update(interactive=False),
198
  *dropdown_placement,
199
  )
200
 
 
202
  ds_labels,
203
  ds_features,
204
  model_id2label,
205
+ uid,
206
  )
207
 
208
  # when labels or features are not aligned
 
215
  gr.update(value=MAPPING_STYLED_ERROR_WARNING, visible=True),
216
  gr.update(visible=False),
217
  gr.update(visible=True, open=True),
218
+ gr.update(interactive=True),
219
  *column_mappings,
220
  )
221
 
 
226
  gr.update(value=get_styled_input(prediction_input), visible=True),
227
  gr.update(value=prediction_output, visible=True),
228
  gr.update(visible=True, open=False),
229
+ gr.update(interactive=True),
230
  *column_mappings,
231
  )
232
 
233
+ def check_column_mapping_keys_validity(all_mappings):
 
 
 
234
  if all_mappings is None:
235
  gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
236
  return (gr.update(interactive=True), gr.update(visible=False))
 
238
  if "labels" not in all_mappings.keys():
239
  gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
240
  return (gr.update(interactive=True), gr.update(visible=False))
241
+
242
+ def construct_label_and_feature_mapping(all_mappings):
243
  label_mapping = {}
244
  for i, label in zip(
245
  range(len(all_mappings["labels"].keys())), all_mappings["labels"].keys()
 
250
  gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
251
  return (gr.update(interactive=True), gr.update(visible=False))
252
  feature_mapping = all_mappings["features"]
253
+ return label_mapping, feature_mapping
254
+
255
+ def try_submit(m_id, d_id, config, split, local, inference, inference_token, uid):
256
+ all_mappings = read_column_mapping(uid)
257
+ check_column_mapping_keys_validity(all_mappings)
258
+ label_mapping, feature_mapping = construct_label_and_feature_mapping(all_mappings)
259
 
260
  leaderboard_dataset = None
261
  if os.environ.get("SPACE_ID") == "giskardai/giskard-evaluator":
262
  leaderboard_dataset = "ZeroCommand/test-giskard-report"
263
+
264
+ if local:
265
+ inference_type = "hf_pipeline"
266
+ if inference and inference_token:
267
+ inference_type = "hf_inference_api"
268
 
269
  # TODO: Set column mapping for some dataset such as `amazon_polarity`
270
+ command = [
271
+ "giskard_scanner",
272
+ "--loader",
273
+ "huggingface",
274
+ "--model",
275
+ m_id,
276
+ "--dataset",
277
+ d_id,
278
+ "--dataset_config",
279
+ config,
280
+ "--dataset_split",
281
+ split,
282
+ "--hf_token",
283
+ os.environ.get(HF_WRITE_TOKEN),
284
+ "--discussion_repo",
285
+ os.environ.get(HF_REPO_ID) or os.environ.get(HF_SPACE_ID),
286
+ "--output_format",
287
+ "markdown",
288
+ "--output_portal",
289
+ "huggingface",
290
+ "--feature_mapping",
291
+ json.dumps(feature_mapping),
292
+ "--label_mapping",
293
+ json.dumps(label_mapping),
294
+ "--scan_config",
295
+ get_yaml_path(uid),
296
+ "--leaderboard_dataset",
297
+ leaderboard_dataset,
298
+ "--inference_type",
299
+ inference_type,
300
+ "--inference_token",
301
+ inference_token,
302
+ ]
303
+ if os.environ.get(HF_GSK_HUB_KEY):
304
+ command.append("--giskard_hub_api_key")
305
+ command.append(os.environ.get(HF_GSK_HUB_KEY))
306
+ if os.environ.get(HF_GSK_HUB_URL):
307
+ command.append("--giskard_hub_url")
308
+ command.append(os.environ.get(HF_GSK_HUB_URL))
309
+ if os.environ.get(HF_GSK_HUB_PROJECT_KEY):
310
+ command.append("--giskard_hub_project_key")
311
+ command.append(os.environ.get(HF_GSK_HUB_PROJECT_KEY))
312
+ if os.environ.get(HF_GSK_HUB_HF_TOKEN):
313
+ command.append("--giskard_hub_hf_token")
314
+ command.append(os.environ.get(HF_GSK_HUB_HF_TOKEN))
315
+ if os.environ.get(HF_GSK_HUB_UNLOCK_TOKEN):
316
+ command.append("--giskard_hub_unlock_token")
317
+ command.append(os.environ.get(HF_GSK_HUB_UNLOCK_TOKEN))
318
+
319
+ eval_str = f"[{m_id}]<{d_id}({config}, {split} set)>"
320
+ logging.info(f"Start local evaluation on {eval_str}")
321
+ save_job_to_pipe(uid, command, eval_str, threading.Lock())
322
+ print(command)
323
+ write_log_to_user_file(
324
+ uid,
325
+ f"Start local evaluation on {eval_str}. Please wait for your job to start...\n",
326
+ )
327
+ gr.Info(f"Start local evaluation on {eval_str}")
328
 
329
+ return (
330
+ gr.update(interactive=False),
331
+ gr.update(lines=5, visible=True, interactive=False),
332
+ )
333
 
 
 
334
 
335
+ # TODO: Submit task to an endpoint")
336
+
337
+ # return (gr.update(interactive=True), gr.update(visible=False)) # Submit button