Spaces:
Running
Running
ZeroCommand
commited on
Commit
•
536b2a2
1
Parent(s):
9037bf7
add pre-check for column mapping values
Browse files- app.py +17 -9
- text_classification.py +11 -2
app.py
CHANGED
@@ -10,7 +10,7 @@ import json
|
|
10 |
|
11 |
from transformers.pipelines import TextClassificationPipeline
|
12 |
|
13 |
-
from text_classification import text_classification_fix_column_mapping
|
14 |
|
15 |
|
16 |
HF_REPO_ID = 'HF_REPO_ID'
|
@@ -233,15 +233,23 @@ with gr.Blocks(theme=theme) as iface:
|
|
233 |
column_mapping = '{}'
|
234 |
if id2label_mapping_dataframe is not None:
|
235 |
column_mapping = id2label_mapping_dataframe.to_json(orient="split")
|
236 |
-
|
237 |
-
|
238 |
-
return try_validate(model_id, dataset_id, dataset_config, dataset_split, column_mapping)
|
239 |
-
else:
|
240 |
return (gr.update(interactive=False),
|
241 |
-
gr.update(
|
242 |
-
gr.update(
|
243 |
-
gr.update(
|
244 |
-
gr.update(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
245 |
with gr.Row():
|
246 |
gr.Markdown('''
|
247 |
<h1 style="text-align: center;">
|
|
|
10 |
|
11 |
from transformers.pipelines import TextClassificationPipeline
|
12 |
|
13 |
+
from text_classification import check_column_mapping_keys_validity, text_classification_fix_column_mapping
|
14 |
|
15 |
|
16 |
HF_REPO_ID = 'HF_REPO_ID'
|
|
|
233 |
column_mapping = '{}'
|
234 |
if id2label_mapping_dataframe is not None:
|
235 |
column_mapping = id2label_mapping_dataframe.to_json(orient="split")
|
236 |
+
if check_column_mapping_keys_validity(column_mapping) is False:
|
237 |
+
gr.Warning('Label mapping table has invalid contents. Please check again.')
|
|
|
|
|
238 |
return (gr.update(interactive=False),
|
239 |
+
gr.update(),
|
240 |
+
gr.update(),
|
241 |
+
gr.update(),
|
242 |
+
gr.update(),
|
243 |
+
gr.update())
|
244 |
+
else:
|
245 |
+
if model_id and dataset_id and dataset_config and dataset_split:
|
246 |
+
return try_validate(model_id, dataset_id, dataset_config, dataset_split, column_mapping)
|
247 |
+
else:
|
248 |
+
return (gr.update(interactive=False),
|
249 |
+
gr.update(visible=True),
|
250 |
+
gr.update(visible=False),
|
251 |
+
gr.update(visible=False),
|
252 |
+
gr.update(visible=False))
|
253 |
with gr.Row():
|
254 |
gr.Markdown('''
|
255 |
<h1 style="text-align: center;">
|
text_classification.py
CHANGED
@@ -1,7 +1,6 @@
|
|
1 |
import datasets
|
2 |
-
|
3 |
import logging
|
4 |
-
|
5 |
import pandas as pd
|
6 |
|
7 |
|
@@ -35,6 +34,16 @@ def text_classification_map_model_and_dataset_labels(id2label, dataset_features)
|
|
35 |
|
36 |
return id2label_mapping, dataset_labels
|
37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
def text_classification_fix_column_mapping(column_mapping, ppl, d_id, config, split):
|
40 |
# We assume dataset is ok here
|
|
|
1 |
import datasets
|
|
|
2 |
import logging
|
3 |
+
import json
|
4 |
import pandas as pd
|
5 |
|
6 |
|
|
|
34 |
|
35 |
return id2label_mapping, dataset_labels
|
36 |
|
37 |
+
def check_column_mapping_keys_validity(column_mapping):
|
38 |
+
# get the element in all the list elements
|
39 |
+
column_mapping = json.loads(column_mapping)
|
40 |
+
if "data" not in column_mapping.keys():
|
41 |
+
return True
|
42 |
+
user_labels = set([pair[0] for pair in column_mapping["data"]])
|
43 |
+
model_labels = set([pair[1] for pair in column_mapping["data"]])
|
44 |
+
|
45 |
+
return user_labels == model_labels
|
46 |
+
|
47 |
|
48 |
def text_classification_fix_column_mapping(column_mapping, ppl, d_id, config, split):
|
49 |
# We assume dataset is ok here
|