Spaces:
Runtime error
Runtime error
Merge pull request #51 from huggingface/add-nli
Browse files
app.py
CHANGED
@@ -16,6 +16,7 @@ from utils import (
|
|
16 |
create_autotrain_project_name,
|
17 |
format_col_mapping,
|
18 |
get_compatible_models,
|
|
|
19 |
get_dataset_card_url,
|
20 |
get_key,
|
21 |
get_metadata,
|
@@ -37,6 +38,7 @@ TASK_TO_ID = {
|
|
37 |
"image_multi_class_classification": 18,
|
38 |
"binary_classification": 1,
|
39 |
"multi_class_classification": 2,
|
|
|
40 |
"entity_extraction": 4,
|
41 |
"extractive_question_answering": 5,
|
42 |
"translation": 6,
|
@@ -51,6 +53,7 @@ TASK_TO_DEFAULT_METRICS = {
|
|
51 |
"recall",
|
52 |
"accuracy",
|
53 |
],
|
|
|
54 |
"entity_extraction": ["precision", "recall", "f1", "accuracy"],
|
55 |
"extractive_question_answering": ["f1", "exact_match"],
|
56 |
"translation": ["sacrebleu"],
|
@@ -72,7 +75,6 @@ AUTOTRAIN_TASK_TO_LANG = {
|
|
72 |
|
73 |
|
74 |
SUPPORTED_TASKS = list(TASK_TO_ID.keys())
|
75 |
-
UNSUPPORTED_TASKS = []
|
76 |
|
77 |
# Extracted from utils.get_supported_metrics
|
78 |
# Hardcoded for now due to speed / caching constraints
|
@@ -118,8 +120,6 @@ SUPPORTED_METRICS = [
|
|
118 |
"jordyvl/ece",
|
119 |
"lvwerra/ai4code",
|
120 |
"lvwerra/amex",
|
121 |
-
"lvwerra/test",
|
122 |
-
"lvwerra/test_metric",
|
123 |
]
|
124 |
|
125 |
|
@@ -180,10 +180,6 @@ if metadata is None:
|
|
180 |
|
181 |
with st.expander("Advanced configuration"):
|
182 |
# Select task
|
183 |
-
# Hack to filter for unsupported tasks
|
184 |
-
# TODO(lewtun): remove this once we have SQuAD metrics support
|
185 |
-
if metadata is not None and metadata[0]["task_id"] in UNSUPPORTED_TASKS:
|
186 |
-
metadata = None
|
187 |
selected_task = st.selectbox(
|
188 |
"Select a task",
|
189 |
SUPPORTED_TASKS,
|
@@ -201,6 +197,9 @@ with st.expander("Advanced configuration"):
|
|
201 |
See the [docs](https://huggingface.co/docs/datasets/master/en/load_hub#configurations) for more details.
|
202 |
""",
|
203 |
)
|
|
|
|
|
|
|
204 |
|
205 |
# Select splits
|
206 |
splits_resp = http_get(
|
@@ -215,8 +214,8 @@ with st.expander("Advanced configuration"):
|
|
215 |
if split["config"] == selected_config:
|
216 |
split_names.append(split["split"])
|
217 |
|
218 |
-
if
|
219 |
-
eval_split =
|
220 |
else:
|
221 |
eval_split = None
|
222 |
selected_split = st.selectbox(
|
@@ -260,16 +259,62 @@ with st.expander("Advanced configuration"):
|
|
260 |
text_col = st.selectbox(
|
261 |
"This column should contain the text to be classified",
|
262 |
col_names,
|
263 |
-
index=col_names.index(get_key(
|
|
|
|
|
264 |
)
|
265 |
target_col = st.selectbox(
|
266 |
"This column should contain the labels associated with the text",
|
267 |
col_names,
|
268 |
-
index=col_names.index(get_key(
|
|
|
|
|
269 |
)
|
270 |
col_mapping[text_col] = "text"
|
271 |
col_mapping[target_col] = "target"
|
272 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
273 |
elif selected_task == "entity_extraction":
|
274 |
with col1:
|
275 |
st.markdown("`tokens` column")
|
@@ -282,12 +327,16 @@ with st.expander("Advanced configuration"):
|
|
282 |
tokens_col = st.selectbox(
|
283 |
"This column should contain the array of tokens to be classified",
|
284 |
col_names,
|
285 |
-
index=col_names.index(get_key(
|
|
|
|
|
286 |
)
|
287 |
tags_col = st.selectbox(
|
288 |
"This column should contain the labels associated with each part of the text",
|
289 |
col_names,
|
290 |
-
index=col_names.index(get_key(
|
|
|
|
|
291 |
)
|
292 |
col_mapping[tokens_col] = "tokens"
|
293 |
col_mapping[tags_col] = "tags"
|
@@ -304,12 +353,16 @@ with st.expander("Advanced configuration"):
|
|
304 |
text_col = st.selectbox(
|
305 |
"This column should contain the text to be translated",
|
306 |
col_names,
|
307 |
-
index=col_names.index(get_key(
|
|
|
|
|
308 |
)
|
309 |
target_col = st.selectbox(
|
310 |
"This column should contain the target translation",
|
311 |
col_names,
|
312 |
-
index=col_names.index(get_key(
|
|
|
|
|
313 |
)
|
314 |
col_mapping[text_col] = "source"
|
315 |
col_mapping[target_col] = "target"
|
@@ -326,19 +379,23 @@ with st.expander("Advanced configuration"):
|
|
326 |
text_col = st.selectbox(
|
327 |
"This column should contain the text to be summarized",
|
328 |
col_names,
|
329 |
-
index=col_names.index(get_key(
|
|
|
|
|
330 |
)
|
331 |
target_col = st.selectbox(
|
332 |
"This column should contain the target summary",
|
333 |
col_names,
|
334 |
-
index=col_names.index(get_key(
|
|
|
|
|
335 |
)
|
336 |
col_mapping[text_col] = "text"
|
337 |
col_mapping[target_col] = "target"
|
338 |
|
339 |
elif selected_task == "extractive_question_answering":
|
340 |
-
if
|
341 |
-
col_mapping =
|
342 |
# Hub YAML parser converts periods to hyphens, so we remap them here
|
343 |
col_mapping = format_col_mapping(col_mapping)
|
344 |
with col1:
|
@@ -362,22 +419,24 @@ with st.expander("Advanced configuration"):
|
|
362 |
context_col = st.selectbox(
|
363 |
"This column should contain the question's context",
|
364 |
col_names,
|
365 |
-
index=col_names.index(get_key(col_mapping, "context")) if
|
366 |
)
|
367 |
question_col = st.selectbox(
|
368 |
"This column should contain the question to be answered, given the context",
|
369 |
col_names,
|
370 |
-
index=col_names.index(get_key(col_mapping, "question")) if
|
371 |
)
|
372 |
answers_text_col = st.selectbox(
|
373 |
"This column should contain example answers to the question, extracted from the context",
|
374 |
col_names,
|
375 |
-
index=col_names.index(get_key(col_mapping, "answers.text")) if
|
376 |
)
|
377 |
answers_start_col = st.selectbox(
|
378 |
"This column should contain the indices in the context of the first character of each `answers.text`",
|
379 |
col_names,
|
380 |
-
index=col_names.index(get_key(col_mapping, "answers.answer_start"))
|
|
|
|
|
381 |
)
|
382 |
col_mapping[context_col] = "context"
|
383 |
col_mapping[question_col] = "question"
|
@@ -395,12 +454,16 @@ with st.expander("Advanced configuration"):
|
|
395 |
image_col = st.selectbox(
|
396 |
"This column should contain the images to be classified",
|
397 |
col_names,
|
398 |
-
index=col_names.index(get_key(
|
|
|
|
|
399 |
)
|
400 |
target_col = st.selectbox(
|
401 |
"This column should contain the labels associated with the images",
|
402 |
col_names,
|
403 |
-
index=col_names.index(get_key(
|
|
|
|
|
404 |
)
|
405 |
col_mapping[image_col] = "image"
|
406 |
col_mapping[target_col] = "target"
|
|
|
16 |
create_autotrain_project_name,
|
17 |
format_col_mapping,
|
18 |
get_compatible_models,
|
19 |
+
get_config_metadata,
|
20 |
get_dataset_card_url,
|
21 |
get_key,
|
22 |
get_metadata,
|
|
|
38 |
"image_multi_class_classification": 18,
|
39 |
"binary_classification": 1,
|
40 |
"multi_class_classification": 2,
|
41 |
+
"natural_language_inference": 22,
|
42 |
"entity_extraction": 4,
|
43 |
"extractive_question_answering": 5,
|
44 |
"translation": 6,
|
|
|
53 |
"recall",
|
54 |
"accuracy",
|
55 |
],
|
56 |
+
"natural_language_inference": ["f1", "precision", "recall", "auc", "accuracy"],
|
57 |
"entity_extraction": ["precision", "recall", "f1", "accuracy"],
|
58 |
"extractive_question_answering": ["f1", "exact_match"],
|
59 |
"translation": ["sacrebleu"],
|
|
|
75 |
|
76 |
|
77 |
SUPPORTED_TASKS = list(TASK_TO_ID.keys())
|
|
|
78 |
|
79 |
# Extracted from utils.get_supported_metrics
|
80 |
# Hardcoded for now due to speed / caching constraints
|
|
|
120 |
"jordyvl/ece",
|
121 |
"lvwerra/ai4code",
|
122 |
"lvwerra/amex",
|
|
|
|
|
123 |
]
|
124 |
|
125 |
|
|
|
180 |
|
181 |
with st.expander("Advanced configuration"):
|
182 |
# Select task
|
|
|
|
|
|
|
|
|
183 |
selected_task = st.selectbox(
|
184 |
"Select a task",
|
185 |
SUPPORTED_TASKS,
|
|
|
197 |
See the [docs](https://huggingface.co/docs/datasets/master/en/load_hub#configurations) for more details.
|
198 |
""",
|
199 |
)
|
200 |
+
# Some datasets have multiple metadata (one per config), so we grab the one associated with the selected config
|
201 |
+
config_metadata = get_config_metadata(selected_config, metadata)
|
202 |
+
print(f"INFO -- Config metadata: {config_metadata}")
|
203 |
|
204 |
# Select splits
|
205 |
splits_resp = http_get(
|
|
|
214 |
if split["config"] == selected_config:
|
215 |
split_names.append(split["split"])
|
216 |
|
217 |
+
if config_metadata is not None:
|
218 |
+
eval_split = config_metadata["splits"].get("eval_split", None)
|
219 |
else:
|
220 |
eval_split = None
|
221 |
selected_split = st.selectbox(
|
|
|
259 |
text_col = st.selectbox(
|
260 |
"This column should contain the text to be classified",
|
261 |
col_names,
|
262 |
+
index=col_names.index(get_key(config_metadata["col_mapping"], "text"))
|
263 |
+
if config_metadata is not None
|
264 |
+
else 0,
|
265 |
)
|
266 |
target_col = st.selectbox(
|
267 |
"This column should contain the labels associated with the text",
|
268 |
col_names,
|
269 |
+
index=col_names.index(get_key(config_metadata["col_mapping"], "target"))
|
270 |
+
if config_metadata is not None
|
271 |
+
else 0,
|
272 |
)
|
273 |
col_mapping[text_col] = "text"
|
274 |
col_mapping[target_col] = "target"
|
275 |
|
276 |
+
if selected_task in ["natural_language_inference"]:
    # Fetch the dataset card metadata for the selected config; None means no defaults are available
    config_metadata = get_config_metadata(selected_config, metadata)

    with col1:
        # One heading per field. The five empty text elements before the 2nd and 3rd headings
        # are presumably vertical spacers lining the headings up with the select boxes -- TODO confirm
        st.markdown("`text1` column")
        for nli_heading in ("`text2` column", "`target` column"):
            for _nli_spacer in range(5):
                st.text("")
            st.markdown(nli_heading)

    with col2:
        # (field name in the column mapping, select box label), in rendering order
        nli_fields = (
            ("text1", "This column should contain the first text passage to be classified"),
            ("text2", "This column should contain the second text passage to be classified"),
            ("target", "This column should contain the labels associated with the text"),
        )
        nli_selection = {}
        for nli_field, nli_label in nli_fields:
            # Preselect the column named in the metadata when there is one, else the first column
            if config_metadata is None:
                nli_default_index = 0
            else:
                nli_default_index = col_names.index(get_key(config_metadata["col_mapping"], nli_field))
            nli_selection[nli_field] = st.selectbox(nli_label, col_names, index=nli_default_index)

        text1_col = nli_selection["text1"]
        text2_col = nli_selection["text2"]
        target_col = nli_selection["target"]

        # Record the user's choices; assignment order matters if the same column is picked twice
        col_mapping[text1_col] = "text1"
        col_mapping[text2_col] = "text2"
        col_mapping[target_col] = "target"
|
317 |
+
|
318 |
elif selected_task == "entity_extraction":
|
319 |
with col1:
|
320 |
st.markdown("`tokens` column")
|
|
|
327 |
tokens_col = st.selectbox(
|
328 |
"This column should contain the array of tokens to be classified",
|
329 |
col_names,
|
330 |
+
index=col_names.index(get_key(config_metadata["col_mapping"], "tokens"))
|
331 |
+
if config_metadata is not None
|
332 |
+
else 0,
|
333 |
)
|
334 |
tags_col = st.selectbox(
|
335 |
"This column should contain the labels associated with each part of the text",
|
336 |
col_names,
|
337 |
+
index=col_names.index(get_key(config_metadata["col_mapping"], "tags"))
|
338 |
+
if config_metadata is not None
|
339 |
+
else 0,
|
340 |
)
|
341 |
col_mapping[tokens_col] = "tokens"
|
342 |
col_mapping[tags_col] = "tags"
|
|
|
353 |
text_col = st.selectbox(
|
354 |
"This column should contain the text to be translated",
|
355 |
col_names,
|
356 |
+
index=col_names.index(get_key(config_metadata["col_mapping"], "source"))
|
357 |
+
if config_metadata is not None
|
358 |
+
else 0,
|
359 |
)
|
360 |
target_col = st.selectbox(
|
361 |
"This column should contain the target translation",
|
362 |
col_names,
|
363 |
+
index=col_names.index(get_key(config_metadata["col_mapping"], "target"))
|
364 |
+
if config_metadata is not None
|
365 |
+
else 0,
|
366 |
)
|
367 |
col_mapping[text_col] = "source"
|
368 |
col_mapping[target_col] = "target"
|
|
|
379 |
text_col = st.selectbox(
|
380 |
"This column should contain the text to be summarized",
|
381 |
col_names,
|
382 |
+
index=col_names.index(get_key(config_metadata["col_mapping"], "text"))
|
383 |
+
if config_metadata is not None
|
384 |
+
else 0,
|
385 |
)
|
386 |
target_col = st.selectbox(
|
387 |
"This column should contain the target summary",
|
388 |
col_names,
|
389 |
+
index=col_names.index(get_key(config_metadata["col_mapping"], "target"))
|
390 |
+
if config_metadata is not None
|
391 |
+
else 0,
|
392 |
)
|
393 |
col_mapping[text_col] = "text"
|
394 |
col_mapping[target_col] = "target"
|
395 |
|
396 |
elif selected_task == "extractive_question_answering":
|
397 |
+
if config_metadata is not None:
|
398 |
+
col_mapping = config_metadata["col_mapping"]
|
399 |
# Hub YAML parser converts periods to hyphens, so we remap them here
|
400 |
col_mapping = format_col_mapping(col_mapping)
|
401 |
with col1:
|
|
|
419 |
context_col = st.selectbox(
|
420 |
"This column should contain the question's context",
|
421 |
col_names,
|
422 |
+
index=col_names.index(get_key(col_mapping, "context")) if config_metadata is not None else 0,
|
423 |
)
|
424 |
question_col = st.selectbox(
|
425 |
"This column should contain the question to be answered, given the context",
|
426 |
col_names,
|
427 |
+
index=col_names.index(get_key(col_mapping, "question")) if config_metadata is not None else 0,
|
428 |
)
|
429 |
answers_text_col = st.selectbox(
|
430 |
"This column should contain example answers to the question, extracted from the context",
|
431 |
col_names,
|
432 |
+
index=col_names.index(get_key(col_mapping, "answers.text")) if config_metadata is not None else 0,
|
433 |
)
|
434 |
answers_start_col = st.selectbox(
|
435 |
"This column should contain the indices in the context of the first character of each `answers.text`",
|
436 |
col_names,
|
437 |
+
index=col_names.index(get_key(col_mapping, "answers.answer_start"))
|
438 |
+
if config_metadata is not None
|
439 |
+
else 0,
|
440 |
)
|
441 |
col_mapping[context_col] = "context"
|
442 |
col_mapping[question_col] = "question"
|
|
|
454 |
image_col = st.selectbox(
|
455 |
"This column should contain the images to be classified",
|
456 |
col_names,
|
457 |
+
index=col_names.index(get_key(config_metadata["col_mapping"], "image"))
|
458 |
+
if config_metadata is not None
|
459 |
+
else 0,
|
460 |
)
|
461 |
target_col = st.selectbox(
|
462 |
"This column should contain the labels associated with the images",
|
463 |
col_names,
|
464 |
+
index=col_names.index(get_key(config_metadata["col_mapping"], "target"))
|
465 |
+
if config_metadata is not None
|
466 |
+
else 0,
|
467 |
)
|
468 |
col_mapping[image_col] = "image"
|
469 |
col_mapping[target_col] = "target"
|
utils.py
CHANGED
@@ -12,6 +12,7 @@ from tqdm import tqdm
|
|
12 |
AUTOTRAIN_TASK_TO_HUB_TASK = {
|
13 |
"binary_classification": "text-classification",
|
14 |
"multi_class_classification": "text-classification",
|
|
|
15 |
"entity_extraction": "token-classification",
|
16 |
"extractive_question_answering": "question-answering",
|
17 |
"translation": "translation",
|
@@ -197,3 +198,14 @@ def create_autotrain_project_name(dataset_id: str) -> str:
|
|
197 |
# Project names need to be unique, so we append a random string to guarantee this
|
198 |
project_id = str(uuid.uuid4())[:8]
|
199 |
return f"eval-project-{dataset_id_formatted}-{project_id}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
AUTOTRAIN_TASK_TO_HUB_TASK = {
|
13 |
"binary_classification": "text-classification",
|
14 |
"multi_class_classification": "text-classification",
|
15 |
+
"natural_language_inference": "text-classification",
|
16 |
"entity_extraction": "token-classification",
|
17 |
"extractive_question_answering": "question-answering",
|
18 |
"translation": "translation",
|
|
|
198 |
# Project names need to be unique, so we append a random string to guarantee this
|
199 |
project_id = str(uuid.uuid4())[:8]
|
200 |
return f"eval-project-{dataset_id_formatted}-{project_id}"
|
201 |
+
|
202 |
+
|
203 |
+
def get_config_metadata(config: str, metadata: Union[List[Dict], None] = None) -> Union[Dict, None]:
    """Gets the dataset card metadata for the given config.

    Args:
        config: Name of the dataset configuration to look up.
        metadata: Metadata entries to search; each entry is matched on its
            ``"config"`` key. May be ``None`` or empty.

    Returns:
        The first entry whose ``"config"`` value equals ``config``, or ``None``
        when ``metadata`` is ``None`` or empty, or when no entry matches.
    """
    if metadata is None:
        return None
    for entry in metadata:
        # An entry without a "config" key cannot match; skip it rather than raise KeyError
        if "config" in entry and entry["config"] == config:
            return entry
    return None
|