Commit 7487fdb
Parent(s): 8217e92
bug-fix-label-mapping-align-with-correct-idx (#80)

- fix the label mapping order; fix out of scope error (bc6f52dabdda8f4e5aa9c9d980ffd3d2c8a55c49)

Co-authored-by: zcy <ZeroCommand@users.noreply.huggingface.co>
- app.py +1 -3
- app_leaderboard.py +6 -1
- text_classification_ui_helpers.py +17 -9
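
The gist of the fix: the label mapping passed to evaluation is now keyed by the dataset's class index, in dataset label order, instead of following the insertion order of the saved mapping. A minimal sketch of that alignment with made-up label names (the real labels come from the Space's column-mapping UI), mirroring the loop added to construct_label_and_feature_mapping further down:

# Hypothetical dataset labels and a saved column mapping (dataset label -> model label).
ds_labels = ["negative", "positive"]
saved_labels = {"negative": "LABEL_0", "positive": "LABEL_1"}

# Key the mapping by the dataset's class index so the indices sent to the
# evaluator line up with the dataset's own label ids.
label_mapping = {}
for i, label in zip(range(len(ds_labels)), ds_labels):
    label_mapping[str(i)] = saved_labels[label]

print(label_mapping)  # {'0': 'LABEL_0', '1': 'LABEL_1'}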
app.py
CHANGED
@@ -12,12 +12,10 @@ try:
     with gr.Tab("Text Classification"):
         get_demo_text_classification()
     with gr.Tab("Leaderboard") as leaderboard_tab:
-        get_demo_leaderboard()
+        get_demo_leaderboard(leaderboard_tab)
     with gr.Tab("Logs(Debug)"):
         get_demo_debug()
 
-    leaderboard_tab.select(fn=get_demo_leaderboard)
-
     start_process_run_job()
 
     demo.queue(max_size=1000)
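
Why app.py changes this way: the old code wired leaderboard_tab.select(fn=get_demo_leaderboard), so opening the tab re-ran the whole leaderboard builder; the new code passes the tab into get_demo_leaderboard so the leaderboard module can attach a lighter refresh handler itself (update_leaderboard_records in the next file). A rough sketch of that Gradio pattern, with hypothetical stand-ins for the Space's record loader:

import gradio as gr

def fetch_rows():
    # stand-in for get_records_from_dataset_repo(...)
    return [["model-a", "dataset-b", 0.9]]

records = {"rows": fetch_rows()}  # stand-in for leaderboard.records

def refresh_records():
    # lightweight: re-fetch the data only, never rebuild components
    records["rows"] = fetch_rows()

def get_demo_leaderboard(tab: gr.Tab):
    gr.DataFrame(value=records["rows"])
    # the builder owns its select handler, which is why app.py now passes the tab in
    tab.select(fn=refresh_records)

with gr.Blocks() as demo:
    with gr.Tab("Leaderboard") as leaderboard_tab:
        get_demo_leaderboard(leaderboard_tab)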
app_leaderboard.py
CHANGED
@@ -73,8 +73,11 @@ def get_display_df(df):
     )
     return display_df
 
+def update_leaderboard_records():
+    logger.info("Updating leaderboard records")
+    leaderboard.records = get_records_from_dataset_repo(leaderboard.LEADERBOARD)
 
-def get_demo():
+def get_demo(leaderboard_tab):
     logger.info("Loading leaderboard records")
     leaderboard.records = get_records_from_dataset_repo(leaderboard.LEADERBOARD)
     records = leaderboard.records
@@ -116,6 +119,8 @@ def get_demo():
     with gr.Row():
         leaderboard_df = gr.DataFrame(display_df, datatype=types, interactive=False)
 
+    leaderboard_tab.select(fn=update_leaderboard_records)
+
     @gr.on(
         triggers=[
             model_select.change,
text_classification_ui_helpers.py
CHANGED
@@ -30,7 +30,6 @@ MAX_FEATURES = 20
 ds_dict = None
 ds_config = None
 
-
 def get_related_datasets_from_leaderboard(model_id):
     records = leaderboard.records
     model_records = records[records["model_id"] == model_id]
@@ -100,7 +99,7 @@ def export_mappings(all_mappings, key, subkeys, values):
     if subkeys is None:
         subkeys = list(all_mappings[key].keys())
 
-    if not subkeys:
+    if not subkeys:
         logging.debug(f"subkeys is empty for {key}")
         return all_mappings
 
@@ -121,6 +120,8 @@ def list_labels_and_features_from_dataset(ds_labels, ds_features, model_labels,
         ds_labels = ds_labels[:MAX_LABELS]
         gr.Warning(f"The number of labels is truncated to length {MAX_LABELS}")
 
+    # sort labels to make sure the order is consistent
+    # prediction gives the order based on probability
     ds_labels.sort()
     model_labels.sort()
 
@@ -293,17 +294,20 @@ def check_column_mapping_keys_validity(all_mappings):
     return (gr.update(interactive=True), gr.update(visible=False))
 
 
-def construct_label_and_feature_mapping(all_mappings):
+def construct_label_and_feature_mapping(all_mappings, ds_labels, ds_features):
     label_mapping = {}
-
-
-
-
+    if len(all_mappings["labels"].keys()) != len(ds_labels):
+        gr.Warning("Label mapping corrupted: " + CONFIRM_MAPPING_DETAILS_FAIL_RAW)
+
+    if len(all_mappings["features"].keys()) != len(ds_features):
+        gr.Warning("Feature mapping corrupted: " + CONFIRM_MAPPING_DETAILS_FAIL_RAW)
+
+    for i, label in zip(range(len(ds_labels)), ds_labels):
+        # align the saved labels with dataset labels order
         label_mapping.update({str(i): all_mappings["labels"][label]})
 
     if "features" not in all_mappings.keys():
         gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
-        return (gr.update(interactive=True), gr.update(visible=False))
     feature_mapping = all_mappings["features"]
     return label_mapping, feature_mapping
 
@@ -311,7 +315,11 @@ def construct_label_and_feature_mapping(all_mappings):
 def try_submit(m_id, d_id, config, split, inference, inference_token, uid):
     all_mappings = read_column_mapping(uid)
     check_column_mapping_keys_validity(all_mappings)
-
+
+    # get ds labels and features again for alignment
+    ds = datasets.load_dataset(d_id, config)[split]
+    ds_labels, ds_features = get_labels_and_features_from_dataset(ds)
+    label_mapping, feature_mapping = construct_label_and_feature_mapping(all_mappings, ds_labels, ds_features)
 
     eval_str = f"[{m_id}]<{d_id}({config}, {split} set)>"
     save_job_to_pipe(