Add NLI support
app.py
CHANGED
@@ -36,6 +36,7 @@ TASK_TO_ID = {
     "image_multi_class_classification": 18,
     "binary_classification": 1,
     "multi_class_classification": 2,
+    "natural_language_inference": 22,
     "entity_extraction": 4,
     "extractive_question_answering": 5,
     "translation": 6,
@@ -50,6 +51,7 @@ TASK_TO_DEFAULT_METRICS = {
         "recall",
         "accuracy",
     ],
+    "natural_language_inference": ["f1", "precision", "recall", "auc", "accuracy"],
     "entity_extraction": ["precision", "recall", "f1", "accuracy"],
     "extractive_question_answering": ["f1", "exact_match"],
     "translation": ["sacrebleu"],
@@ -117,11 +119,19 @@ SUPPORTED_METRICS = [
     "jordyvl/ece",
     "lvwerra/ai4code",
     "lvwerra/amex",
-    "lvwerra/test",
-    "lvwerra/test_metric",
 ]
 
 
+def get_config_metadata(config, metadata=None):
+    if metadata is None:
+        return None
+    config_metadata = [m for m in metadata if m["config"] == config]
+    if len(config_metadata) == 1:
+        return config_metadata[0]
+    else:
+        return None
+
+
 #######
 # APP #
 #######
@@ -269,6 +279,47 @@ with st.expander("Advanced configuration"):
         col_mapping[text_col] = "text"
         col_mapping[target_col] = "target"
 
+    col_mapping = {}
+    if selected_task in ["natural_language_inference"]:
+        config_metadata = get_config_metadata(selected_config, metadata)
+        with col1:
+            st.markdown("`text1` column")
+            st.text("")
+            st.text("")
+            st.text("")
+            st.text("")
+            st.markdown("`text2` column")
+            st.text("")
+            st.text("")
+            st.text("")
+            st.text("")
+            st.markdown("`target` column")
+        with col2:
+            text1_col = st.selectbox(
+                "This column should contain the first text passage to be classified",
+                col_names,
+                index=col_names.index(get_key(config_metadata["col_mapping"], "text1"))
+                if config_metadata is not None
+                else 0,
+            )
+            text2_col = st.selectbox(
+                "This column should contain the second text passage to be classified",
+                col_names,
+                index=col_names.index(get_key(config_metadata["col_mapping"], "text2"))
+                if config_metadata is not None
+                else 0,
+            )
+            target_col = st.selectbox(
+                "This column should contain the labels associated with the text",
+                col_names,
+                index=col_names.index(get_key(config_metadata["col_mapping"], "target"))
+                if config_metadata is not None
+                else 0,
+            )
+        col_mapping[text1_col] = "text1"
+        col_mapping[text2_col] = "text2"
+        col_mapping[target_col] = "target"
+
     elif selected_task == "entity_extraction":
         with col1:
             st.markdown("`tokens` column")
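Note (not part of the commit): the new get_config_metadata helper above simply returns the single metadata entry whose "config" field matches the selected dataset config, and None otherwise. Below is a minimal, self-contained sketch of that behaviour; the sample metadata list and its MNLI-style column names are hypothetical and only meant for illustration.

```python
# Hypothetical usage sketch of the get_config_metadata helper from the diff.
# The sample `metadata` list below is made up; in the app the metadata comes
# from the dataset's evaluation metadata, not from these literals.

def get_config_metadata(config, metadata=None):
    # Return the single entry whose "config" field matches, else None.
    if metadata is None:
        return None
    config_metadata = [m for m in metadata if m["config"] == config]
    if len(config_metadata) == 1:
        return config_metadata[0]
    else:
        return None


metadata = [
    {
        "config": "plain_text",
        "col_mapping": {"premise": "text1", "hypothesis": "text2", "label": "target"},
    }
]

print(get_config_metadata("plain_text", metadata))     # -> the matching entry
print(get_config_metadata("other_config", metadata))   # -> None (no unique match)
print(get_config_metadata("plain_text"))               # -> None (no metadata supplied)
```

In the app itself, the returned entry's col_mapping is then inverted with get_key(...) so that the Streamlit selectboxes pre-select the columns recorded in the dataset metadata.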
utils.py
CHANGED
@@ -12,6 +12,7 @@ from tqdm import tqdm
 AUTOTRAIN_TASK_TO_HUB_TASK = {
     "binary_classification": "text-classification",
     "multi_class_classification": "text-classification",
+    "natural_language_inference": "text-classification",
     "entity_extraction": "token-classification",
     "extractive_question_answering": "question-answering",
     "translation": "translation",
}
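For a rough end-to-end picture, here is an illustrative sketch (not part of the commit) of what the two changes combine to do: the new dictionary entry in utils.py routes natural_language_inference to the Hub's text-classification task, while the new Streamlit branch in app.py maps the user-selected columns onto the text1/text2/target schema. The MNLI-style column names below are assumptions used purely for illustration.

```python
# Illustrative only; the MNLI-style column names are assumptions, not from the diff.
AUTOTRAIN_TASK_TO_HUB_TASK = {
    "binary_classification": "text-classification",
    "multi_class_classification": "text-classification",
    "natural_language_inference": "text-classification",  # added in this commit
    "entity_extraction": "token-classification",
    "extractive_question_answering": "question-answering",
    "translation": "translation",
}

selected_task = "natural_language_inference"

# Columns a user might pick in the three selectboxes for an MNLI-style dataset.
text1_col, text2_col, target_col = "premise", "hypothesis", "label"

col_mapping = {}
col_mapping[text1_col] = "text1"
col_mapping[text2_col] = "text2"
col_mapping[target_col] = "target"

print(AUTOTRAIN_TASK_TO_HUB_TASK[selected_task])  # text-classification
print(col_mapping)  # {'premise': 'text1', 'hypothesis': 'text2', 'label': 'target'}
```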