Spaces:
Running
Running
ZeroCommand
commited on
Commit
•
ac0eaff
1
Parent(s):
536b2a2
polish up and add more information
Browse files- app.py +11 -12
- text_classification.py +9 -4
app.py
CHANGED
@@ -59,11 +59,10 @@ def check_dataset(dataset_id, dataset_config="default", dataset_split="test"):
|
|
59 |
return dataset_id, None, None
|
60 |
return dataset_id, dataset_config, dataset_split
|
61 |
|
62 |
-
def try_validate(
|
63 |
# Validate model
|
64 |
-
m_id, ppl = check_model(model_id=model_id)
|
65 |
if m_id is None:
|
66 |
-
gr.Warning(
|
67 |
return (
|
68 |
gr.update(interactive=False), # Submit button
|
69 |
gr.update(visible=True), # Loading row
|
@@ -73,7 +72,7 @@ def try_validate(model_id, dataset_id, dataset_config, dataset_split, column_map
|
|
73 |
gr.update(visible=False), # Label mapping preview
|
74 |
)
|
75 |
if isinstance(ppl, Exception):
|
76 |
-
gr.Warning(f'Failed to load
|
77 |
return (
|
78 |
gr.update(interactive=False), # Submit button
|
79 |
gr.update(visible=True), # Loading row
|
@@ -124,8 +123,6 @@ def try_validate(model_id, dataset_id, dataset_config, dataset_split, column_map
|
|
124 |
|
125 |
column_mapping = json.dumps(column_mapping, indent=2)
|
126 |
|
127 |
-
del ppl
|
128 |
-
|
129 |
if prediction_result is None:
|
130 |
gr.Warning('The model failed to predict with the first row in the dataset. Please provide column mappings in "Advance" settings.')
|
131 |
return (
|
@@ -212,7 +209,6 @@ with gr.Blocks(theme=theme) as iface:
|
|
212 |
def check_dataset_and_get_config(dataset_id):
|
213 |
try:
|
214 |
configs = datasets.get_dataset_config_names(dataset_id)
|
215 |
-
print(configs)
|
216 |
return gr.Dropdown(configs, value=configs[0], visible=True)
|
217 |
except Exception:
|
218 |
# Dataset may not exist
|
@@ -221,19 +217,19 @@ with gr.Blocks(theme=theme) as iface:
|
|
221 |
def check_dataset_and_get_split(dataset_config, dataset_id):
|
222 |
try:
|
223 |
splits = list(datasets.load_dataset(dataset_id, dataset_config).keys())
|
224 |
-
print('splits: ',splits)
|
225 |
return gr.Dropdown(splits, value=splits[0], visible=True)
|
226 |
except Exception as e:
|
227 |
# Dataset may not exist
|
228 |
-
|
229 |
pass
|
230 |
|
231 |
def gate_validate_btn(model_id, dataset_id, dataset_config, dataset_split, id2label_mapping_dataframe=None):
|
232 |
-
print('model_id: ',model_id)
|
233 |
column_mapping = '{}'
|
|
|
|
|
234 |
if id2label_mapping_dataframe is not None:
|
235 |
column_mapping = id2label_mapping_dataframe.to_json(orient="split")
|
236 |
-
if check_column_mapping_keys_validity(column_mapping) is False:
|
237 |
gr.Warning('Label mapping table has invalid contents. Please check again.')
|
238 |
return (gr.update(interactive=False),
|
239 |
gr.update(),
|
@@ -243,12 +239,15 @@ with gr.Blocks(theme=theme) as iface:
|
|
243 |
gr.update())
|
244 |
else:
|
245 |
if model_id and dataset_id and dataset_config and dataset_split:
|
246 |
-
return try_validate(
|
247 |
else:
|
|
|
|
|
248 |
return (gr.update(interactive=False),
|
249 |
gr.update(visible=True),
|
250 |
gr.update(visible=False),
|
251 |
gr.update(visible=False),
|
|
|
252 |
gr.update(visible=False))
|
253 |
with gr.Row():
|
254 |
gr.Markdown('''
|
|
|
59 |
return dataset_id, None, None
|
60 |
return dataset_id, dataset_config, dataset_split
|
61 |
|
62 |
+
def try_validate(m_id, ppl, dataset_id, dataset_config, dataset_split, column_mapping='{}'):
|
63 |
# Validate model
|
|
|
64 |
if m_id is None:
|
65 |
+
gr.Warning('Model is not accessible. Please set your HF_TOKEN if it is a private model.')
|
66 |
return (
|
67 |
gr.update(interactive=False), # Submit button
|
68 |
gr.update(visible=True), # Loading row
|
|
|
72 |
gr.update(visible=False), # Label mapping preview
|
73 |
)
|
74 |
if isinstance(ppl, Exception):
|
75 |
+
gr.Warning(f'Failed to load model": {ppl}')
|
76 |
return (
|
77 |
gr.update(interactive=False), # Submit button
|
78 |
gr.update(visible=True), # Loading row
|
|
|
123 |
|
124 |
column_mapping = json.dumps(column_mapping, indent=2)
|
125 |
|
|
|
|
|
126 |
if prediction_result is None:
|
127 |
gr.Warning('The model failed to predict with the first row in the dataset. Please provide column mappings in "Advance" settings.')
|
128 |
return (
|
|
|
209 |
def check_dataset_and_get_config(dataset_id):
|
210 |
try:
|
211 |
configs = datasets.get_dataset_config_names(dataset_id)
|
|
|
212 |
return gr.Dropdown(configs, value=configs[0], visible=True)
|
213 |
except Exception:
|
214 |
# Dataset may not exist
|
|
|
217 |
def check_dataset_and_get_split(dataset_config, dataset_id):
|
218 |
try:
|
219 |
splits = list(datasets.load_dataset(dataset_id, dataset_config).keys())
|
|
|
220 |
return gr.Dropdown(splits, value=splits[0], visible=True)
|
221 |
except Exception as e:
|
222 |
# Dataset may not exist
|
223 |
+
gr.Warning(f"Failed to load dataset {dataset_id} with config {dataset_config}: {e}")
|
224 |
pass
|
225 |
|
226 |
def gate_validate_btn(model_id, dataset_id, dataset_config, dataset_split, id2label_mapping_dataframe=None):
|
|
|
227 |
column_mapping = '{}'
|
228 |
+
m_id, ppl = check_model(model_id=model_id)
|
229 |
+
|
230 |
if id2label_mapping_dataframe is not None:
|
231 |
column_mapping = id2label_mapping_dataframe.to_json(orient="split")
|
232 |
+
if check_column_mapping_keys_validity(column_mapping, ppl) is False:
|
233 |
gr.Warning('Label mapping table has invalid contents. Please check again.')
|
234 |
return (gr.update(interactive=False),
|
235 |
gr.update(),
|
|
|
239 |
gr.update())
|
240 |
else:
|
241 |
if model_id and dataset_id and dataset_config and dataset_split:
|
242 |
+
return try_validate(m_id, ppl, dataset_id, dataset_config, dataset_split, column_mapping)
|
243 |
else:
|
244 |
+
del ppl
|
245 |
+
|
246 |
return (gr.update(interactive=False),
|
247 |
gr.update(visible=True),
|
248 |
gr.update(visible=False),
|
249 |
gr.update(visible=False),
|
250 |
+
gr.update(visible=False),
|
251 |
gr.update(visible=False))
|
252 |
with gr.Row():
|
253 |
gr.Markdown('''
|
text_classification.py
CHANGED
@@ -34,15 +34,19 @@ def text_classification_map_model_and_dataset_labels(id2label, dataset_features)
|
|
34 |
|
35 |
return id2label_mapping, dataset_labels
|
36 |
|
37 |
-
|
|
|
38 |
# get the element in all the list elements
|
39 |
column_mapping = json.loads(column_mapping)
|
40 |
if "data" not in column_mapping.keys():
|
41 |
return True
|
42 |
user_labels = set([pair[0] for pair in column_mapping["data"]])
|
43 |
model_labels = set([pair[1] for pair in column_mapping["data"]])
|
|
|
|
|
|
|
44 |
|
45 |
-
return user_labels == model_labels
|
46 |
|
47 |
|
48 |
def text_classification_fix_column_mapping(column_mapping, ppl, d_id, config, split):
|
@@ -100,7 +104,6 @@ def text_classification_fix_column_mapping(column_mapping, ppl, d_id, config, sp
|
|
100 |
if isinstance(column_mapping["data"], list):
|
101 |
# Use the column mapping passed by user
|
102 |
for user_label, model_label in column_mapping["data"]:
|
103 |
-
print(user_label, model_label)
|
104 |
id2label_mapping[model_label] = user_label
|
105 |
elif None in id2label_mapping.values():
|
106 |
column_mapping["label"] = {
|
@@ -108,7 +111,9 @@ def text_classification_fix_column_mapping(column_mapping, ppl, d_id, config, sp
|
|
108 |
}
|
109 |
return column_mapping, prediction_result, None
|
110 |
|
111 |
-
|
|
|
|
|
112 |
id2label_df = pd.DataFrame({
|
113 |
"Dataset Labels": dataset_labels,
|
114 |
"Model Prediction Labels": [id2label_mapping[label] for label in dataset_labels],
|
|
|
34 |
|
35 |
return id2label_mapping, dataset_labels
|
36 |
|
37 |
+
|
38 |
+
def check_column_mapping_keys_validity(column_mapping, ppl):
|
39 |
# get the element in all the list elements
|
40 |
column_mapping = json.loads(column_mapping)
|
41 |
if "data" not in column_mapping.keys():
|
42 |
return True
|
43 |
user_labels = set([pair[0] for pair in column_mapping["data"]])
|
44 |
model_labels = set([pair[1] for pair in column_mapping["data"]])
|
45 |
+
|
46 |
+
id2label = ppl.model.config.id2label
|
47 |
+
original_labels = set(id2label.values())
|
48 |
|
49 |
+
return user_labels == model_labels == original_labels
|
50 |
|
51 |
|
52 |
def text_classification_fix_column_mapping(column_mapping, ppl, d_id, config, split):
|
|
|
104 |
if isinstance(column_mapping["data"], list):
|
105 |
# Use the column mapping passed by user
|
106 |
for user_label, model_label in column_mapping["data"]:
|
|
|
107 |
id2label_mapping[model_label] = user_label
|
108 |
elif None in id2label_mapping.values():
|
109 |
column_mapping["label"] = {
|
|
|
111 |
}
|
112 |
return column_mapping, prediction_result, None
|
113 |
|
114 |
+
prediction_result = {
|
115 |
+
f'[{label2id[result["label"]]}]{result["label"]}(original) - {id2label_mapping[result["label"]]}(mapped)': result["score"] for result in results
|
116 |
+
}
|
117 |
id2label_df = pd.DataFrame({
|
118 |
"Dataset Labels": dataset_labels,
|
119 |
"Model Prediction Labels": [id2label_mapping[label] for label in dataset_labels],
|