lewtun HF staff commited on
Commit
533bc81
1 Parent(s): 6348497

Add NLI support

Browse files
Files changed (2) hide show
  1. app.py +53 -2
  2. utils.py +1 -0
app.py CHANGED
@@ -36,6 +36,7 @@ TASK_TO_ID = {
36
  "image_multi_class_classification": 18,
37
  "binary_classification": 1,
38
  "multi_class_classification": 2,
 
39
  "entity_extraction": 4,
40
  "extractive_question_answering": 5,
41
  "translation": 6,
@@ -50,6 +51,7 @@ TASK_TO_DEFAULT_METRICS = {
50
  "recall",
51
  "accuracy",
52
  ],
 
53
  "entity_extraction": ["precision", "recall", "f1", "accuracy"],
54
  "extractive_question_answering": ["f1", "exact_match"],
55
  "translation": ["sacrebleu"],
@@ -117,11 +119,19 @@ SUPPORTED_METRICS = [
117
  "jordyvl/ece",
118
  "lvwerra/ai4code",
119
  "lvwerra/amex",
120
- "lvwerra/test",
121
- "lvwerra/test_metric",
122
  ]
123
 
124
 
 
 
 
 
 
 
 
 
 
 
125
  #######
126
  # APP #
127
  #######
@@ -269,6 +279,47 @@ with st.expander("Advanced configuration"):
269
  col_mapping[text_col] = "text"
270
  col_mapping[target_col] = "target"
271
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
  elif selected_task == "entity_extraction":
273
  with col1:
274
  st.markdown("`tokens` column")
 
36
  "image_multi_class_classification": 18,
37
  "binary_classification": 1,
38
  "multi_class_classification": 2,
39
+ "natural_language_inference": 22,
40
  "entity_extraction": 4,
41
  "extractive_question_answering": 5,
42
  "translation": 6,
 
51
  "recall",
52
  "accuracy",
53
  ],
54
+ "natural_language_inference": ["f1", "precision", "recall", "auc", "accuracy"],
55
  "entity_extraction": ["precision", "recall", "f1", "accuracy"],
56
  "extractive_question_answering": ["f1", "exact_match"],
57
  "translation": ["sacrebleu"],
 
119
  "jordyvl/ece",
120
  "lvwerra/ai4code",
121
  "lvwerra/amex",
 
 
122
  ]
123
 
124
 
125
+ def get_config_metadata(config, metadata=None):
126
+ if metadata is None:
127
+ return None
128
+ config_metadata = [m for m in metadata if m["config"] == config]
129
+ if len(config_metadata) == 1:
130
+ return config_metadata[0]
131
+ else:
132
+ return None
133
+
134
+
135
  #######
136
  # APP #
137
  #######
 
279
  col_mapping[text_col] = "text"
280
  col_mapping[target_col] = "target"
281
 
282
+ col_mapping = {}
283
+ if selected_task in ["natural_language_inference"]:
284
+ config_metadata = get_config_metadata(selected_config, metadata)
285
+ with col1:
286
+ st.markdown("`text1` column")
287
+ st.text("")
288
+ st.text("")
289
+ st.text("")
290
+ st.text("")
291
+ st.markdown("`text2` column")
292
+ st.text("")
293
+ st.text("")
294
+ st.text("")
295
+ st.text("")
296
+ st.markdown("`target` column")
297
+ with col2:
298
+ text1_col = st.selectbox(
299
+ "This column should contain the first text passage to be classified",
300
+ col_names,
301
+ index=col_names.index(get_key(config_metadata["col_mapping"], "text1"))
302
+ if config_metadata is not None
303
+ else 0,
304
+ )
305
+ text2_col = st.selectbox(
306
+ "This column should contain the second text passage to be classified",
307
+ col_names,
308
+ index=col_names.index(get_key(config_metadata["col_mapping"], "text2"))
309
+ if config_metadata is not None
310
+ else 0,
311
+ )
312
+ target_col = st.selectbox(
313
+ "This column should contain the labels associated with the text",
314
+ col_names,
315
+ index=col_names.index(get_key(config_metadata["col_mapping"], "target"))
316
+ if config_metadata is not None
317
+ else 0,
318
+ )
319
+ col_mapping[text1_col] = "text1"
320
+ col_mapping[text2_col] = "text2"
321
+ col_mapping[target_col] = "target"
322
+
323
  elif selected_task == "entity_extraction":
324
  with col1:
325
  st.markdown("`tokens` column")
utils.py CHANGED
@@ -12,6 +12,7 @@ from tqdm import tqdm
12
  AUTOTRAIN_TASK_TO_HUB_TASK = {
13
  "binary_classification": "text-classification",
14
  "multi_class_classification": "text-classification",
 
15
  "entity_extraction": "token-classification",
16
  "extractive_question_answering": "question-answering",
17
  "translation": "translation",
 
12
  AUTOTRAIN_TASK_TO_HUB_TASK = {
13
  "binary_classification": "text-classification",
14
  "multi_class_classification": "text-classification",
15
+ "natural_language_inference": "text-classification",
16
  "entity_extraction": "token-classification",
17
  "extractive_question_answering": "question-answering",
18
  "translation": "translation",