ZeroCommand commited on
Commit
fc7c452
·
1 Parent(s): 21e0bb3

move inference api parameters

Browse files
app_text_classification.py CHANGED
@@ -2,8 +2,7 @@ import uuid
2
 
3
  import gradio as gr
4
 
5
- from io_utils import (get_logs_file, read_inference_type, read_scanners,
6
- write_inference_type, write_scanners)
7
  from text_classification_ui_helpers import (check_dataset_and_get_config,
8
  check_dataset_and_get_split,
9
  check_model_and_show_prediction,
@@ -65,15 +64,7 @@ def get_demo():
65
 
66
  with gr.Accordion(label="Model Wrap Advance Config (optional)", open=False):
67
  run_local = gr.Checkbox(value=True, label="Run in this Space")
68
- run_inference = gr.Checkbox(value="False", label="Run with Inference API")
69
-
70
- @gr.on(triggers=[uid_label.change], inputs=[uid_label], outputs=[run_inference])
71
- def get_run_mode(uid):
72
- return gr.update(
73
- value=read_inference_type(uid) == "hf_inference_api"
74
- and not run_local.value
75
- )
76
-
77
  inference_token = gr.Textbox(
78
  value="",
79
  label="HF Token for Inference API",
@@ -122,7 +113,7 @@ def get_demo():
122
 
123
  run_inference.change(
124
  select_run_mode,
125
- inputs=[run_inference, inference_token, uid_label],
126
  outputs=[inference_token, run_local],
127
  )
128
 
@@ -131,11 +122,7 @@ def get_demo():
131
  inputs=[run_local],
132
  outputs=[inference_token, run_inference],
133
  )
134
-
135
- inference_token.change(
136
- write_inference_type, inputs=[run_inference, inference_token, uid_label]
137
- )
138
-
139
  gr.on(
140
  triggers=[label.change for label in column_mappings],
141
  fn=write_column_mapping_to_config,
@@ -189,6 +176,8 @@ def get_demo():
189
  dataset_config_input,
190
  dataset_split_input,
191
  run_local,
 
 
192
  uid_label,
193
  ],
194
  outputs=[run_btn, logs],
@@ -204,6 +193,7 @@ def get_demo():
204
  dataset_split_input.change,
205
  run_inference.change,
206
  run_local.change,
 
207
  scanners.change,
208
  ],
209
  fn=enable_run_btn,
 
2
 
3
  import gradio as gr
4
 
5
+ from io_utils import (get_logs_file, read_scanners, write_scanners)
 
6
  from text_classification_ui_helpers import (check_dataset_and_get_config,
7
  check_dataset_and_get_split,
8
  check_model_and_show_prediction,
 
64
 
65
  with gr.Accordion(label="Model Wrap Advance Config (optional)", open=False):
66
  run_local = gr.Checkbox(value=True, label="Run in this Space")
67
+ run_inference = gr.Checkbox(value=False, label="Run with Inference API")
 
 
 
 
 
 
 
 
68
  inference_token = gr.Textbox(
69
  value="",
70
  label="HF Token for Inference API",
 
113
 
114
  run_inference.change(
115
  select_run_mode,
116
+ inputs=[run_inference],
117
  outputs=[inference_token, run_local],
118
  )
119
 
 
122
  inputs=[run_local],
123
  outputs=[inference_token, run_inference],
124
  )
125
+
 
 
 
 
126
  gr.on(
127
  triggers=[label.change for label in column_mappings],
128
  fn=write_column_mapping_to_config,
 
176
  dataset_config_input,
177
  dataset_split_input,
178
  run_local,
179
+ run_inference,
180
+ inference_token,
181
  uid_label,
182
  ],
183
  outputs=[run_btn, logs],
 
193
  dataset_split_input.change,
194
  run_inference.change,
195
  run_local.change,
196
+ inference_token.change,
197
  scanners.change,
198
  ],
199
  fn=enable_run_btn,
text_classification_ui_helpers.py CHANGED
@@ -10,7 +10,7 @@ from transformers.pipelines import TextClassificationPipeline
10
  from wordings import get_styled_input
11
 
12
  from io_utils import (get_yaml_path, read_column_mapping, save_job_to_pipe,
13
- write_column_mapping, write_inference_type,
14
  write_log_to_user_file)
15
  from text_classification import (check_model, get_example_prediction,
16
  get_labels_and_features_from_dataset)
@@ -51,10 +51,8 @@ def check_dataset_and_get_split(dataset_id, dataset_config):
51
  pass
52
 
53
 
54
- def select_run_mode(run_inf, inf_token, uid):
55
  if run_inf:
56
- if len(inf_token) > 0:
57
- write_inference_type(run_inf, inf_token, uid)
58
  return (gr.update(visible=True), gr.update(value=False))
59
  else:
60
  return (gr.update(visible=False), gr.update(value=True))
@@ -213,10 +211,7 @@ def check_model_and_show_prediction(
213
  *column_mappings,
214
  )
215
 
216
-
217
- def try_submit(m_id, d_id, config, split, local, uid):
218
- all_mappings = read_column_mapping(uid)
219
-
220
  if all_mappings is None:
221
  gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
222
  return (gr.update(interactive=True), gr.update(visible=False))
@@ -224,6 +219,8 @@ def try_submit(m_id, d_id, config, split, local, uid):
224
  if "labels" not in all_mappings.keys():
225
  gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
226
  return (gr.update(interactive=True), gr.update(visible=False))
 
 
227
  label_mapping = {}
228
  for i, label in zip(
229
  range(len(all_mappings["labels"].keys())), all_mappings["labels"].keys()
@@ -234,73 +231,88 @@ def try_submit(m_id, d_id, config, split, local, uid):
234
  gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
235
  return (gr.update(interactive=True), gr.update(visible=False))
236
  feature_mapping = all_mappings["features"]
 
 
 
 
 
 
237
 
238
  leaderboard_dataset = None
239
  if os.environ.get("SPACE_ID") == "giskardai/giskard-evaluator":
240
  leaderboard_dataset = "ZeroCommand/test-giskard-report"
 
 
 
 
 
241
 
242
  # TODO: Set column mapping for some dataset such as `amazon_polarity`
243
- if local:
244
- command = [
245
- "giskard_scanner",
246
- "--loader",
247
- "huggingface",
248
- "--model",
249
- m_id,
250
- "--dataset",
251
- d_id,
252
- "--dataset_config",
253
- config,
254
- "--dataset_split",
255
- split,
256
- "--hf_token",
257
- os.environ.get(HF_WRITE_TOKEN),
258
- "--discussion_repo",
259
- os.environ.get(HF_REPO_ID) or os.environ.get(HF_SPACE_ID),
260
- "--output_format",
261
- "markdown",
262
- "--output_portal",
263
- "huggingface",
264
- "--feature_mapping",
265
- json.dumps(feature_mapping),
266
- "--label_mapping",
267
- json.dumps(label_mapping),
268
- "--scan_config",
269
- get_yaml_path(uid),
270
- "--leaderboard_dataset",
271
- leaderboard_dataset,
272
- ]
273
- if os.environ.get(HF_GSK_HUB_KEY):
274
- command.append("--giskard_hub_api_key")
275
- command.append(os.environ.get(HF_GSK_HUB_KEY))
276
- if os.environ.get(HF_GSK_HUB_URL):
277
- command.append("--giskard_hub_url")
278
- command.append(os.environ.get(HF_GSK_HUB_URL))
279
- if os.environ.get(HF_GSK_HUB_PROJECT_KEY):
280
- command.append("--giskard_hub_project_key")
281
- command.append(os.environ.get(HF_GSK_HUB_PROJECT_KEY))
282
- if os.environ.get(HF_GSK_HUB_HF_TOKEN):
283
- command.append("--giskard_hub_hf_token")
284
- command.append(os.environ.get(HF_GSK_HUB_HF_TOKEN))
285
- if os.environ.get(HF_GSK_HUB_UNLOCK_TOKEN):
286
- command.append("--giskard_hub_unlock_token")
287
- command.append(os.environ.get(HF_GSK_HUB_UNLOCK_TOKEN))
288
-
289
- eval_str = f"[{m_id}]<{d_id}({config}, {split} set)>"
290
- logging.info(f"Start local evaluation on {eval_str}")
291
- save_job_to_pipe(uid, command, eval_str, threading.Lock())
292
- write_log_to_user_file(
293
- uid,
294
- f"Start local evaluation on {eval_str}. Please wait for your job to start...\n",
295
- )
296
- gr.Info(f"Start local evaluation on {eval_str}")
 
 
 
 
297
 
298
- return (
299
- gr.update(interactive=False),
300
- gr.update(lines=5, visible=True, interactive=False),
301
- )
302
 
303
- else:
304
- gr.Info("TODO: Submit task to an endpoint")
305
 
306
- return (gr.update(interactive=True), gr.update(visible=False)) # Submit button
 
 
 
10
  from wordings import get_styled_input
11
 
12
  from io_utils import (get_yaml_path, read_column_mapping, save_job_to_pipe,
13
+ write_column_mapping,
14
  write_log_to_user_file)
15
  from text_classification import (check_model, get_example_prediction,
16
  get_labels_and_features_from_dataset)
 
51
  pass
52
 
53
 
54
+ def select_run_mode(run_inf):
55
  if run_inf:
 
 
56
  return (gr.update(visible=True), gr.update(value=False))
57
  else:
58
  return (gr.update(visible=False), gr.update(value=True))
 
211
  *column_mappings,
212
  )
213
 
214
+ def check_column_mapping_keys_validity(all_mappings):
 
 
 
215
  if all_mappings is None:
216
  gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
217
  return (gr.update(interactive=True), gr.update(visible=False))
 
219
  if "labels" not in all_mappings.keys():
220
  gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
221
  return (gr.update(interactive=True), gr.update(visible=False))
222
+
223
+ def construct_label_and_feature_mapping(all_mappings):
224
  label_mapping = {}
225
  for i, label in zip(
226
  range(len(all_mappings["labels"].keys())), all_mappings["labels"].keys()
 
231
  gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
232
  return (gr.update(interactive=True), gr.update(visible=False))
233
  feature_mapping = all_mappings["features"]
234
+ return label_mapping, feature_mapping
235
+
236
+ def try_submit(m_id, d_id, config, split, local, inference, inference_token, uid):
237
+ all_mappings = read_column_mapping(uid)
238
+ check_column_mapping_keys_validity(all_mappings)
239
+ label_mapping, feature_mapping = construct_label_and_feature_mapping(all_mappings)
240
 
241
  leaderboard_dataset = None
242
  if os.environ.get("SPACE_ID") == "giskardai/giskard-evaluator":
243
  leaderboard_dataset = "ZeroCommand/test-giskard-report"
244
+
245
+ if local:
246
+ inference_type = "hf_pipeline"
247
+ if inference and inference_token:
248
+ inference_type = "hf_inference_api"
249
 
250
  # TODO: Set column mapping for some dataset such as `amazon_polarity`
251
+ command = [
252
+ "giskard_scanner",
253
+ "--loader",
254
+ "huggingface",
255
+ "--model",
256
+ m_id,
257
+ "--dataset",
258
+ d_id,
259
+ "--dataset_config",
260
+ config,
261
+ "--dataset_split",
262
+ split,
263
+ "--hf_token",
264
+ os.environ.get(HF_WRITE_TOKEN),
265
+ "--discussion_repo",
266
+ os.environ.get(HF_REPO_ID) or os.environ.get(HF_SPACE_ID),
267
+ "--output_format",
268
+ "markdown",
269
+ "--output_portal",
270
+ "huggingface",
271
+ "--feature_mapping",
272
+ json.dumps(feature_mapping),
273
+ "--label_mapping",
274
+ json.dumps(label_mapping),
275
+ "--scan_config",
276
+ get_yaml_path(uid),
277
+ "--leaderboard_dataset",
278
+ leaderboard_dataset,
279
+ "--inference_type",
280
+ inference_type,
281
+ "--inference_token",
282
+ inference_token,
283
+ ]
284
+ if os.environ.get(HF_GSK_HUB_KEY):
285
+ command.append("--giskard_hub_api_key")
286
+ command.append(os.environ.get(HF_GSK_HUB_KEY))
287
+ if os.environ.get(HF_GSK_HUB_URL):
288
+ command.append("--giskard_hub_url")
289
+ command.append(os.environ.get(HF_GSK_HUB_URL))
290
+ if os.environ.get(HF_GSK_HUB_PROJECT_KEY):
291
+ command.append("--giskard_hub_project_key")
292
+ command.append(os.environ.get(HF_GSK_HUB_PROJECT_KEY))
293
+ if os.environ.get(HF_GSK_HUB_HF_TOKEN):
294
+ command.append("--giskard_hub_hf_token")
295
+ command.append(os.environ.get(HF_GSK_HUB_HF_TOKEN))
296
+ if os.environ.get(HF_GSK_HUB_UNLOCK_TOKEN):
297
+ command.append("--giskard_hub_unlock_token")
298
+ command.append(os.environ.get(HF_GSK_HUB_UNLOCK_TOKEN))
299
+
300
+ eval_str = f"[{m_id}]<{d_id}({config}, {split} set)>"
301
+ logging.info(f"Start local evaluation on {eval_str}")
302
+ save_job_to_pipe(uid, command, eval_str, threading.Lock())
303
+ print(command)
304
+ write_log_to_user_file(
305
+ uid,
306
+ f"Start local evaluation on {eval_str}. Please wait for your job to start...\n",
307
+ )
308
+ gr.Info(f"Start local evaluation on {eval_str}")
309
 
310
+ return (
311
+ gr.update(interactive=False),
312
+ gr.update(lines=5, visible=True, interactive=False),
313
+ )
314
 
 
 
315
 
316
+ # TODO: Submit task to an endpoint")
317
+
318
+ # return (gr.update(interactive=True), gr.update(visible=False)) # Submit button