Spaces:
Sleeping
Sleeping
Commit
•
2694247
1
Parent(s):
1204664
GSK-2909-handle-input-too-long-error (#165)
Browse files- fix labels not found; handle input too long (6457b7ab3e357b83e70bad7268307a81d2154bfa)
- Remove leading spaces (b709c217e8e05020af6b864792c2941358e29dfc)
- Remove tailing spaces (6dbdec2b3e7c630a187b0171e19325eb4e2f1126)
- Remove leading spaces (c1a5e7ecbc5505f334a5bf1a458bf1d7ffaa00e3)
Co-authored-by: zcy <ZeroCommand@users.noreply.huggingface.co>
- app_text_classification.py +1 -0
- text_classification.py +12 -0
- text_classification_ui_helpers.py +3 -3
app_text_classification.py
CHANGED
@@ -201,6 +201,7 @@ def get_demo():
|
|
201 |
gr.on(
|
202 |
triggers=[
|
203 |
model_id_input.change,
|
|
|
204 |
dataset_id_input.change,
|
205 |
dataset_config_input.change,
|
206 |
dataset_split_input.change,
|
|
|
201 |
gr.on(
|
202 |
triggers=[
|
203 |
model_id_input.change,
|
204 |
+
model_id_input.input,
|
205 |
dataset_id_input.change,
|
206 |
dataset_config_input.change,
|
207 |
dataset_split_input.change,
|
text_classification.py
CHANGED
@@ -28,10 +28,14 @@ def get_labels_and_features_from_dataset(ds):
|
|
28 |
if len(label_keys) == 0: # no labels found
|
29 |
# return everything for post processing
|
30 |
return list(dataset_features.keys()), list(dataset_features.keys()), None
|
|
|
|
|
31 |
if not isinstance(dataset_features[label_keys[0]], datasets.ClassLabel):
|
32 |
if hasattr(dataset_features[label_keys[0]], "feature"):
|
33 |
label_feat = dataset_features[label_keys[0]].feature
|
34 |
labels = label_feat.names
|
|
|
|
|
35 |
else:
|
36 |
labels = dataset_features[label_keys[0]].names
|
37 |
return labels, features, label_keys
|
@@ -83,9 +87,17 @@ def hf_inference_api(model_id, hf_token, payload):
|
|
83 |
url = f"{hf_inference_api_endpoint}/models/{model_id}"
|
84 |
headers = {"Authorization": f"Bearer {hf_token}"}
|
85 |
response = requests.post(url, headers=headers, json=payload)
|
|
|
86 |
if not hasattr(response, "status_code") or response.status_code != 200:
|
87 |
logger.warning(f"Request to inference API returns {response}")
|
|
|
88 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
return response.json()
|
90 |
except Exception:
|
91 |
return {"error": response.content}
|
|
|
28 |
if len(label_keys) == 0: # no labels found
|
29 |
# return everything for post processing
|
30 |
return list(dataset_features.keys()), list(dataset_features.keys()), None
|
31 |
+
|
32 |
+
labels = None
|
33 |
if not isinstance(dataset_features[label_keys[0]], datasets.ClassLabel):
|
34 |
if hasattr(dataset_features[label_keys[0]], "feature"):
|
35 |
label_feat = dataset_features[label_keys[0]].feature
|
36 |
labels = label_feat.names
|
37 |
+
else:
|
38 |
+
labels = ds.unique(label_keys[0])
|
39 |
else:
|
40 |
labels = dataset_features[label_keys[0]].names
|
41 |
return labels, features, label_keys
|
|
|
87 |
url = f"{hf_inference_api_endpoint}/models/{model_id}"
|
88 |
headers = {"Authorization": f"Bearer {hf_token}"}
|
89 |
response = requests.post(url, headers=headers, json=payload)
|
90 |
+
|
91 |
if not hasattr(response, "status_code") or response.status_code != 200:
|
92 |
logger.warning(f"Request to inference API returns {response}")
|
93 |
+
|
94 |
try:
|
95 |
+
output = response.json()
|
96 |
+
if "error" in output and "Input is too long" in output["error"]:
|
97 |
+
payload.update({"parameters": {"truncation": True, "max_length": 512}})
|
98 |
+
response = requests.post(url, headers=headers, json=payload)
|
99 |
+
if not hasattr(response, "status_code") or response.status_code != 200:
|
100 |
+
logger.warning(f"Request to inference API returns {response}")
|
101 |
return response.json()
|
102 |
except Exception:
|
103 |
return {"error": response.content}
|
text_classification_ui_helpers.py
CHANGED
@@ -341,8 +341,8 @@ def align_columns_and_show_prediction(
|
|
341 |
):
|
342 |
return (
|
343 |
gr.update(value=MAPPING_STYLED_ERROR_WARNING, visible=True),
|
344 |
-
gr.update(visible=
|
345 |
-
gr.update(visible=
|
346 |
gr.update(visible=True, open=True),
|
347 |
gr.update(interactive=(run_inference and inference_token != "")),
|
348 |
"",
|
@@ -351,7 +351,7 @@ def align_columns_and_show_prediction(
|
|
351 |
|
352 |
return (
|
353 |
gr.update(value=VALIDATED_MODEL_DATASET_STYLED, visible=True),
|
354 |
-
gr.update(value=prediction_input, lines=len(prediction_input)//225 + 1, visible=True),
|
355 |
gr.update(value=prediction_response, visible=True),
|
356 |
gr.update(visible=True, open=False),
|
357 |
gr.update(interactive=(run_inference and inference_token != "")),
|
|
|
341 |
):
|
342 |
return (
|
343 |
gr.update(value=MAPPING_STYLED_ERROR_WARNING, visible=True),
|
344 |
+
gr.update(value=prediction_input, lines=min(len(prediction_input)//225 + 1, 5), visible=True),
|
345 |
+
gr.update(value=prediction_response, visible=True),
|
346 |
gr.update(visible=True, open=True),
|
347 |
gr.update(interactive=(run_inference and inference_token != "")),
|
348 |
"",
|
|
|
351 |
|
352 |
return (
|
353 |
gr.update(value=VALIDATED_MODEL_DATASET_STYLED, visible=True),
|
354 |
+
gr.update(value=prediction_input, lines=min(len(prediction_input)//225 + 1, 5), visible=True),
|
355 |
gr.update(value=prediction_response, visible=True),
|
356 |
gr.update(visible=True, open=False),
|
357 |
gr.update(interactive=(run_inference and inference_token != "")),
|