GSK-2909-handle-input-too-long-error

#165
app_text_classification.py CHANGED
@@ -201,6 +201,7 @@ def get_demo():
201
  gr.on(
202
  triggers=[
203
  model_id_input.change,
 
204
  dataset_id_input.change,
205
  dataset_config_input.change,
206
  dataset_split_input.change,
 
201
  gr.on(
202
  triggers=[
203
  model_id_input.change,
204
+ model_id_input.input,
205
  dataset_id_input.change,
206
  dataset_config_input.change,
207
  dataset_split_input.change,
text_classification.py CHANGED
@@ -28,10 +28,14 @@ def get_labels_and_features_from_dataset(ds):
28
  if len(label_keys) == 0: # no labels found
29
  # return everything for post processing
30
  return list(dataset_features.keys()), list(dataset_features.keys()), None
 
 
31
  if not isinstance(dataset_features[label_keys[0]], datasets.ClassLabel):
32
  if hasattr(dataset_features[label_keys[0]], "feature"):
33
  label_feat = dataset_features[label_keys[0]].feature
34
  labels = label_feat.names
 
 
35
  else:
36
  labels = dataset_features[label_keys[0]].names
37
  return labels, features, label_keys
@@ -83,9 +87,17 @@ def hf_inference_api(model_id, hf_token, payload):
83
  url = f"{hf_inference_api_endpoint}/models/{model_id}"
84
  headers = {"Authorization": f"Bearer {hf_token}"}
85
  response = requests.post(url, headers=headers, json=payload)
 
86
  if not hasattr(response, "status_code") or response.status_code != 200:
87
  logger.warning(f"Request to inference API returns {response}")
 
88
  try:
 
 
 
 
 
 
89
  return response.json()
90
  except Exception:
91
  return {"error": response.content}
 
28
  if len(label_keys) == 0: # no labels found
29
  # return everything for post processing
30
  return list(dataset_features.keys()), list(dataset_features.keys()), None
31
+
32
+ labels = None
33
  if not isinstance(dataset_features[label_keys[0]], datasets.ClassLabel):
34
  if hasattr(dataset_features[label_keys[0]], "feature"):
35
  label_feat = dataset_features[label_keys[0]].feature
36
  labels = label_feat.names
37
+ else:
38
+ labels = ds.unique(label_keys[0])
39
  else:
40
  labels = dataset_features[label_keys[0]].names
41
  return labels, features, label_keys
 
87
  url = f"{hf_inference_api_endpoint}/models/{model_id}"
88
  headers = {"Authorization": f"Bearer {hf_token}"}
89
  response = requests.post(url, headers=headers, json=payload)
90
+
91
  if not hasattr(response, "status_code") or response.status_code != 200:
92
  logger.warning(f"Request to inference API returns {response}")
93
+
94
  try:
95
+ output = response.json()
96
+ if "error" in output and "Input is too long" in output["error"]:
97
+ payload.update({"parameters": {"truncation": True, "max_length": 512}})
98
+ response = requests.post(url, headers=headers, json=payload)
99
+ if not hasattr(response, "status_code") or response.status_code != 200:
100
+ logger.warning(f"Request to inference API returns {response}")
101
  return response.json()
102
  except Exception:
103
  return {"error": response.content}
text_classification_ui_helpers.py CHANGED
@@ -341,8 +341,8 @@ def align_columns_and_show_prediction(
341
  ):
342
  return (
343
  gr.update(value=MAPPING_STYLED_ERROR_WARNING, visible=True),
344
- gr.update(visible=False),
345
- gr.update(visible=False),
346
  gr.update(visible=True, open=True),
347
  gr.update(interactive=(run_inference and inference_token != "")),
348
  "",
@@ -351,7 +351,7 @@ def align_columns_and_show_prediction(
351
 
352
  return (
353
  gr.update(value=VALIDATED_MODEL_DATASET_STYLED, visible=True),
354
- gr.update(value=prediction_input, lines=len(prediction_input)//225 + 1, visible=True),
355
  gr.update(value=prediction_response, visible=True),
356
  gr.update(visible=True, open=False),
357
  gr.update(interactive=(run_inference and inference_token != "")),
 
341
  ):
342
  return (
343
  gr.update(value=MAPPING_STYLED_ERROR_WARNING, visible=True),
344
+ gr.update(value=prediction_input, lines=min(len(prediction_input)//225 + 1, 5), visible=True),
345
+ gr.update(value=prediction_response, visible=True),
346
  gr.update(visible=True, open=True),
347
  gr.update(interactive=(run_inference and inference_token != "")),
348
  "",
 
351
 
352
  return (
353
  gr.update(value=VALIDATED_MODEL_DATASET_STYLED, visible=True),
354
+ gr.update(value=prediction_input, lines=min(len(prediction_input)//225 + 1, 5), visible=True),
355
  gr.update(value=prediction_response, visible=True),
356
  gr.update(visible=True, open=False),
357
  gr.update(interactive=(run_inference and inference_token != "")),