ZeroCommand committed • Commit 9f91529 • Parent(s): 2569404
return label keys and add in mapping before submission
text_classification.py
CHANGED
@@ -26,23 +26,22 @@ def get_labels_and_features_from_dataset(ds):
         features = [f for f in dataset_features.keys() if not f.startswith("label")]
         if len(label_keys) == 0: # no labels found
             label_keys = [i for i in dataset_features.keys() if i.startswith('label')]
-            features += label_keys
 
         if len(label_keys) == 0: # no labels found
             # return everything for post processing
-            return list(dataset_features.keys()), list(dataset_features.keys())
+            return list(dataset_features.keys()), list(dataset_features.keys()), None
         if not isinstance(dataset_features[label_keys[0]], datasets.ClassLabel):
             if hasattr(dataset_features[label_keys[0]], 'feature'):
                 label_feat = dataset_features[label_keys[0]].feature
                 labels = label_feat.names
         else:
             labels = dataset_features[label_keys[0]].names
-        return labels, features
+        return labels, features, label_keys
     except Exception as e:
         logging.warning(
             f"Get Labels/Features Failed for dataset: {e}"
         )
-        return None, None
+        return None, None, None
 
 def check_model_task(model_id):
     # check if model is valid on huggingface
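Below is a minimal sketch of the new three-tuple return value on a toy dataset. The toy data, the import path, and the expected results are illustrative assumptions based on the diff above, not part of the commit.

# Illustrative sketch only (assumes the module is importable as `text_classification`).
import datasets
from text_classification import get_labels_and_features_from_dataset

# Toy dataset with a ClassLabel column named "label" (assumed example data).
toy_ds = datasets.Dataset.from_dict(
    {"text": ["good movie", "bad movie"], "label": [1, 0]},
    features=datasets.Features({
        "text": datasets.Value("string"),
        "label": datasets.ClassLabel(names=["negative", "positive"]),
    }),
)

ds_labels, ds_features, label_keys = get_labels_and_features_from_dataset(toy_ds)
# Expected per the updated function: ds_labels == ["negative", "positive"],
# ds_features == ["text"], and label_keys == ["label"] — the raw label column
# name(s) that callers can now forward to the mapping step before submission.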
text_classification_ui_helpers.py
CHANGED
@@ -198,7 +198,7 @@ def precheck_model_ds_enable_example_btn(
     try:
         ds = datasets.load_dataset(dataset_id, dataset_config, trust_remote_code=True)
         df: pd.DataFrame = ds[dataset_split].to_pandas().head(5)
-        ds_labels, ds_features = get_labels_and_features_from_dataset(ds[dataset_split])
+        ds_labels, ds_features, _ = get_labels_and_features_from_dataset(ds[dataset_split])
 
         if model_task is None or model_task != "text-classification":
             gr.Warning(NOT_TEXT_CLASSIFICATION_MODEL_RAW)
@@ -300,7 +300,7 @@ def align_columns_and_show_prediction(
     model_labels = list(prediction_response.keys())
 
     ds = datasets.load_dataset(dataset_id, dataset_config, split=dataset_split, trust_remote_code=True)
-    ds_labels, ds_features = get_labels_and_features_from_dataset(ds)
+    ds_labels, ds_features, _ = get_labels_and_features_from_dataset(ds)
 
     # when dataset does not have labels or features
     if not isinstance(ds_labels, list) or not isinstance(ds_features, list):
@@ -390,7 +390,7 @@ def enable_run_btn(uid, run_inference, inference_token, model_id, dataset_id, da
         return gr.update(interactive=False)
     return gr.update(interactive=True)
 
-def construct_label_and_feature_mapping(all_mappings, ds_labels, ds_features):
+def construct_label_and_feature_mapping(all_mappings, ds_labels, ds_features, label_keys=None):
     label_mapping = {}
     if len(all_mappings["labels"].keys()) != len(ds_labels):
         logger.warn(f"""Label mapping corrupted: {CONFIRM_MAPPING_DETAILS_FAIL_RAW}.
@@ -408,9 +408,9 @@ def construct_label_and_feature_mapping(all_mappings, ds_labels, ds_features):
         logger.warning("features not in all_mappings")
         gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
 
-    special_classlabel = [feature for feature in ds_features if feature.startswith("label")]
     feature_mapping = all_mappings["features"]
-
+    if len(label_keys) > 0:
+        feature_mapping.update({"label": label_keys[0]})
     return label_mapping, feature_mapping
 
 def show_hf_token_info(token):
@@ -426,8 +426,8 @@ def try_submit(m_id, d_id, config, split, inference, inference_token, uid):
 
     # get ds labels and features again for alignment
     ds = datasets.load_dataset(d_id, config, split=split, trust_remote_code=True)
-    ds_labels, ds_features = get_labels_and_features_from_dataset(ds)
-    label_mapping, feature_mapping = construct_label_and_feature_mapping(all_mappings, ds_labels, ds_features)
+    ds_labels, ds_features, label_keys = get_labels_and_features_from_dataset(ds)
+    label_mapping, feature_mapping = construct_label_and_feature_mapping(all_mappings, ds_labels, ds_features, label_keys)
 
     eval_str = f"[{m_id}]<{d_id}({config}, {split} set)>"
     save_job_to_pipe(
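For illustration, the effect of the new label_keys argument can be reproduced in isolation. The all_mappings contents below are an assumed shape used only to show the added lines at work, not values taken from the Space's actual state.

# Standalone sketch of only the lines this commit adds to construct_label_and_feature_mapping.
all_mappings = {
    "labels": {"negative": "NEGATIVE", "positive": "POSITIVE"},  # assumed shape
    "features": {"text": "sentence"},                            # assumed shape
}
label_keys = ["label"]  # as returned by get_labels_and_features_from_dataset

feature_mapping = all_mappings["features"]
if len(label_keys) > 0:
    feature_mapping.update({"label": label_keys[0]})

print(feature_mapping)  # {'text': 'sentence', 'label': 'label'}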