ZeroCommand committed
Commit 201d156
Parent: d9ca844

handle inference API error; fix non-text dataset columns

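In short: `hf_inference_api` drops its blind 30-attempt retry loop in favor of a single request whose error payload is returned to the caller; a new `preload_hf_inference_api` helper warms the model on the Inference API as soon as the user selects it; `get_example_prediction` now surfaces the API's `estimated_time` and error responses instead of swallowing them; and datasets without a `text` column now fall back to their first string column rather than taking the first column of any type.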
app_text_classification.py CHANGED

```diff
@@ -128,7 +128,11 @@ def get_demo():
             fn=get_related_datasets_from_leaderboard,
             inputs=[model_id_input],
             outputs=[dataset_id_input],
-        ).then(fn=check_dataset, inputs=[dataset_id_input], outputs=[dataset_config_input, dataset_split_input, loading_status])
+        ).then(
+            fn=check_dataset,
+            inputs=[dataset_id_input],
+            outputs=[dataset_config_input, dataset_split_input, loading_status]
+        )
 
         gr.on(
             triggers=[dataset_id_input.input],
```
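This hunk only reflows a one-line `.then()` call for readability; behavior is unchanged. For readers unfamiliar with the pattern, here is a minimal sketch of Gradio's event chaining (Gradio 4.x style; the component and handler names below are invented for illustration, not taken from the app):

```python
# Minimal sketch of gr.on(...).then(...) chaining. All names are illustrative.
import gradio as gr

def fetch_suggestions(model_id):
    return f"datasets related to {model_id}"

def validate_dataset(dataset_id):
    return f"validated {dataset_id}"

with gr.Blocks() as demo:
    model_box = gr.Textbox(label="model id")
    dataset_box = gr.Textbox(label="dataset id")
    status_box = gr.Textbox(label="status")
    # gr.on registers the first handler; .then queues a follow-up
    # that runs only after the first handler has finished.
    gr.on(
        triggers=[model_box.input],
        fn=fetch_suggestions,
        inputs=[model_box],
        outputs=[dataset_box],
    ).then(
        fn=validate_dataset,
        inputs=[dataset_box],
        outputs=[status_box],
    )

if __name__ == "__main__":
    demo.launch()
```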
text_classification.py CHANGED

```diff
@@ -9,6 +9,7 @@ import requests
 import os
 import time
 
+logger = logging.getLogger(__name__)
 HF_WRITE_TOKEN = "HF_WRITE_TOKEN"
 
 logger = logging.getLogger(__file__)
@@ -76,19 +77,18 @@ def hf_inference_api(model_id, hf_token, payload):
     )
     url = f"{hf_inference_api_endpoint}/models/{model_id}"
     headers = {"Authorization": f"Bearer {hf_token}"}
-    output = {"error": "First attemp"}
-    attempt = 30
-    while "error" in output and attempt > 0:
-        response = requests.post(url, headers=headers, json=payload)
-        if response.status_code != 200:
-            logging.error(f"Request to inference API returns {response.status_code}")
-        try:
-            return response.json()
-        except Exception:
-            logging.error(f"{response.content}")
-            output = {"error": response.content}
-        attempt -= 1
-        time.sleep(2)
+    response = requests.post(url, headers=headers, json=payload)
+    if not hasattr(response, "status_code") or response.status_code != 200:
+        logger.warning(f"Request to inference API returns {response}")
+    try:
+        return response.json()
+    except Exception:
+        return {"error": response.content}
+
+def preload_hf_inference_api(model_id):
+    payload = {"inputs": "This is a test", "options": {"use_cache": True}}
+    hf_token = os.environ.get(HF_WRITE_TOKEN, default="")
+    hf_inference_api(model_id, hf_token, payload)
 
 def check_model_pipeline(model_id):
     try:
```
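The rewrite drops the 30 x 2-second retry loop: `hf_inference_api` now sends one POST and returns whatever JSON comes back, leaving retry policy to the caller. (Note that the first hunk adds a `logging.getLogger(__name__)` logger while keeping the later `logging.getLogger(__file__)` assignment, so the second one is what `logger` ends up bound to.) A hedged sketch of the responses a single call like this can produce, based on the public Inference API behavior; the model id is just an example:

```python
# Hedged sketch: what one POST to the HF Inference API can return.
# Endpoint and response shapes follow the public Inference API docs;
# the model id is an arbitrary example.
import os
import requests

model_id = "distilbert-base-uncased-finetuned-sst-2-english"
url = f"https://api-inference.huggingface.co/models/{model_id}"
headers = {"Authorization": f"Bearer {os.environ.get('HF_WRITE_TOKEN', '')}"}
payload = {"inputs": "This is a test", "options": {"use_cache": True}}

response = requests.post(url, headers=headers, json=payload)
body = response.json()
# Cold model -> HTTP 503 with a loading hint:
#   {"error": "Model ... is currently loading", "estimated_time": 20.0}
# Warm text-classification model -> nested list of label scores:
#   [[{"label": "POSITIVE", "score": 0.99}, {"label": "NEGATIVE", "score": 0.01}]]
# This is why get_example_prediction (below) checks for "estimated_time"
# and "error" keys before unwrapping lists.
print(body)
```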
```diff
@@ -262,6 +262,12 @@ def check_dataset_features_validity(d_id, config, split):
 
     return df, dataset_features
 
+def select_the_first_string_column(ds):
+    for feature in ds.features.keys():
+        if isinstance(ds[0][feature], str):
+            return feature
+    return None
+
 
 def get_example_prediction(model_id, dataset_id, dataset_config, dataset_split):
     # get a sample prediction from the model on the dataset
```
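`select_the_first_string_column` replaces the old `list(ds.features.keys())[0]` lookup, which happily returned an int or float column when the dataset had no `text` column. A quick check on a toy dataset (the data below is invented, and the helper is assumed to be in scope):

```python
# Toy check of select_the_first_string_column; the example data is invented.
import datasets

ds = datasets.Dataset.from_dict(
    {"idx": [0, 1], "review": ["great movie", "terrible movie"]}
)
# "idx" holds ints, so the first *string* column is "review".
assert select_the_first_string_column(ds) == "review"
# Caveat: it returns None when no string column exists, and the caller
# then indexes ds[0][None]; worth a guard upstream.
```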
```diff
@@ -272,13 +278,20 @@ def get_example_prediction(model_id, dataset_id, dataset_config, dataset_split):
         ds = datasets.load_dataset(dataset_id, dataset_config)[dataset_split]
         if "text" not in ds.features.keys():
             # Dataset does not have text column
-            prediction_input = ds[0][list(ds.features.keys())[0]]
+            prediction_input = ds[0][select_the_first_string_column(ds)]
         else:
             prediction_input = ds[0]["text"]
-
+
         hf_token = os.environ.get(HF_WRITE_TOKEN, default="")
         payload = {"inputs": prediction_input, "options": {"use_cache": True}}
         results = hf_inference_api(model_id, hf_token, payload)
+
+        if isinstance(results, dict) and "estimated_time" in results.keys():
+            return prediction_input, str(results["estimated_time"])
+
+        if isinstance(results, dict) and "error" in results.keys():
+            raise ValueError(results["error"])
+
         while isinstance(results, list):
             if isinstance(results[0], dict):
                 break
@@ -288,8 +301,7 @@ def get_example_prediction(model_id, dataset_id, dataset_config, dataset_split):
         }
     except Exception as e:
         # Pipeline prediction failed, need to provide labels
-        logger.warn(f"Pipeline prediction failed due to {e}")
-        return prediction_input, None
+        return prediction_input, e
 
     return prediction_input, prediction_result
 
```
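Taken together, these hunks give `get_example_prediction` a three-way contract: a string back means the model is still loading (the value is the API's `estimated_time`), an `Exception` means the call failed (the `ValueError` raised on an `"error"` payload is caught by this same `except` block), and a dict means real label scores. A sketch of how a caller is expected to branch; the model and dataset ids are illustrative placeholders:

```python
# Sketch of consuming get_example_prediction's three possible outcomes.
# Model/dataset ids below are illustrative placeholders.
prediction_input, prediction_response = get_example_prediction(
    "distilbert-base-uncased-finetuned-sst-2-english", "sst2", "default", "train"
)
if isinstance(prediction_response, str):
    # still warming up; the string is the API's estimated_time in seconds
    print(f"Model loading, retry in ~{prediction_response}s")
elif isinstance(prediction_response, Exception):
    # request or parsing failed; surface the error to the user
    print(f"Inference failed: {prediction_response}")
else:
    # normal case, e.g. {"POSITIVE": 0.99, "NEGATIVE": 0.01}
    print(prediction_response)
```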
text_classification_ui_helpers.py CHANGED

```diff
@@ -12,6 +12,7 @@ from io_utils import read_column_mapping, write_column_mapping
 from run_jobs import save_job_to_pipe
 from text_classification import (
     check_model_task,
+    preload_hf_inference_api,
     get_example_prediction,
     get_labels_and_features_from_dataset,
 )
@@ -159,9 +160,10 @@ def precheck_model_ds_enable_example_btn(
     model_id, dataset_id, dataset_config, dataset_split
 ):
     model_task = check_model_task(model_id)
+    preload_hf_inference_api(model_id)
     if model_task is None or model_task != "text-classification":
         gr.Warning("Please check your model.")
-        return gr.update(interactive=False), ""
+        return (gr.update(), gr.update(), "")
 
     if dataset_config is None or dataset_split is None or len(dataset_config) == 0:
         return (gr.update(), gr.update(), "")
```
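Two things change in `precheck_model_ds_enable_example_btn`: the model is preloaded as soon as it is known, and the early return on a wrong model task now yields three values instead of two. The latter matters because every exit path of a Gradio handler must return one value per wired output component; a mismatched count fails at runtime. A minimal sketch of the rule, assuming three hypothetical outputs:

```python
# Every return path must match the number of wired outputs (three here).
# Handler and update values are hypothetical.
import gradio as gr

def precheck(ok):
    if not ok:
        # no-op updates still have to be returned, one per output
        return (gr.update(), gr.update(), "")
    return (gr.update(interactive=True), gr.update(visible=True), "ready")
```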
```diff
@@ -182,8 +184,6 @@
         return (gr.update(interactive=False), gr.update(value=pd.DataFrame(), visible=False), "")
 
 
-
-
 def align_columns_and_show_prediction(
     model_id,
     dataset_id,
@@ -209,12 +209,32 @@ def align_columns_and_show_prediction(
         gr.Dropdown(visible=False) for _ in range(MAX_LABELS + MAX_FEATURES)
     ]
 
-    # FIXME: prefiction_output could be None
-    prediction_input, prediction_output = get_example_prediction(
+    prediction_input, prediction_response = get_example_prediction(
         model_id, dataset_id, dataset_config, dataset_split
     )
 
-    model_labels = list(prediction_output.keys())
+    if isinstance(prediction_response, str):
+        return (
+            gr.update(visible=False),
+            gr.update(visible=False),
+            gr.update(visible=False, open=False),
+            gr.update(interactive=False),
+            f"Hugging Face Inference API is loading your model, estimation time {prediction_response}",
+            *dropdown_placement,
+        )
+
+    if isinstance(prediction_response, Exception):
+        gr.Warning("Please check your model or Hugging Face token.")
+        return (
+            gr.update(visible=False),
+            gr.update(visible=False),
+            gr.update(visible=False, open=False),
+            gr.update(interactive=False),
+            f"Sorry, inference api loading error {prediction_response}, please check your model and token.",
+            *dropdown_placement,
+        )
+
+    model_labels = list(prediction_response.keys())
 
     ds = datasets.load_dataset(dataset_id, dataset_config)[dataset_split]
     ds_labels, ds_features = get_labels_and_features_from_dataset(ds)
```
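The early returns above end with `*dropdown_placement`, unpacking one hidden-dropdown update per label/feature slot so the tuple length always matches the handler's full output list. A sketch of the pattern; the constants and names are illustrative:

```python
# Returning a variable-length tail of component updates via unpacking.
# MAX_LABELS / MAX_FEATURES values and names are illustrative.
import gradio as gr

MAX_LABELS, MAX_FEATURES = 8, 4
dropdown_placement = [
    gr.Dropdown(visible=False) for _ in range(MAX_LABELS + MAX_FEATURES)
]

def bail_out(message):
    # fixed prefix of updates, then one update per dropdown slot
    return (
        gr.update(visible=False),
        gr.update(interactive=False),
        message,
        *dropdown_placement,
    )
```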
```diff
@@ -255,7 +275,7 @@ def align_columns_and_show_prediction(
 
     return (
         gr.update(value=get_styled_input(prediction_input), visible=True),
-        gr.update(value=prediction_output, visible=True),
+        gr.update(value=prediction_response, visible=True),
         gr.update(visible=True, open=False),
         gr.update(interactive=(run_inference and inference_token != "")),
         "",
```