Spaces:
Running
Running
ZeroCommand
commited on
Commit
•
2569404
1
Parent(s):
8e32a09
fix feature mapping by adding multi labels case
Browse files
text_classification.py
CHANGED
@@ -22,7 +22,12 @@ class HuggingFaceInferenceAPIResponse:
|
|
22 |
def get_labels_and_features_from_dataset(ds):
|
23 |
try:
|
24 |
dataset_features = ds.features
|
25 |
-
label_keys = [i for i in dataset_features.keys() if i
|
|
|
|
|
|
|
|
|
|
|
26 |
if len(label_keys) == 0: # no labels found
|
27 |
# return everything for post processing
|
28 |
return list(dataset_features.keys()), list(dataset_features.keys())
|
@@ -32,7 +37,6 @@ def get_labels_and_features_from_dataset(ds):
|
|
32 |
labels = label_feat.names
|
33 |
else:
|
34 |
labels = dataset_features[label_keys[0]].names
|
35 |
-
features = [f for f in dataset_features.keys() if not f.startswith("label")]
|
36 |
return labels, features
|
37 |
except Exception as e:
|
38 |
logging.warning(
|
|
|
22 |
def get_labels_and_features_from_dataset(ds):
|
23 |
try:
|
24 |
dataset_features = ds.features
|
25 |
+
label_keys = [i for i in dataset_features.keys() if i == 'label']
|
26 |
+
features = [f for f in dataset_features.keys() if not f.startswith("label")]
|
27 |
+
if len(label_keys) == 0: # no labels found
|
28 |
+
label_keys = [i for i in dataset_features.keys() if i.startswith('label')]
|
29 |
+
features += label_keys
|
30 |
+
|
31 |
if len(label_keys) == 0: # no labels found
|
32 |
# return everything for post processing
|
33 |
return list(dataset_features.keys()), list(dataset_features.keys())
|
|
|
37 |
labels = label_feat.names
|
38 |
else:
|
39 |
labels = dataset_features[label_keys[0]].names
|
|
|
40 |
return labels, features
|
41 |
except Exception as e:
|
42 |
logging.warning(
|
text_classification_ui_helpers.py
CHANGED
@@ -138,7 +138,7 @@ def list_labels_and_features_from_dataset(ds_labels, ds_features, model_labels,
|
|
138 |
ds_labels = list(shared_labels)
|
139 |
if len(ds_labels) > MAX_LABELS:
|
140 |
ds_labels = ds_labels[:MAX_LABELS]
|
141 |
-
gr.Warning(f"
|
142 |
|
143 |
# sort labels to make sure the order is consistent
|
144 |
# prediction gives the order based on probability
|
@@ -393,10 +393,12 @@ def enable_run_btn(uid, run_inference, inference_token, model_id, dataset_id, da
|
|
393 |
def construct_label_and_feature_mapping(all_mappings, ds_labels, ds_features):
|
394 |
label_mapping = {}
|
395 |
if len(all_mappings["labels"].keys()) != len(ds_labels):
|
396 |
-
logger.warn("Label mapping corrupted:
|
|
|
397 |
|
398 |
if len(all_mappings["features"].keys()) != len(ds_features):
|
399 |
-
logger.warn("Feature mapping corrupted:
|
|
|
400 |
|
401 |
for i, label in zip(range(len(ds_labels)), ds_labels):
|
402 |
# align the saved labels with dataset labels order
|
@@ -405,7 +407,10 @@ def construct_label_and_feature_mapping(all_mappings, ds_labels, ds_features):
|
|
405 |
if "features" not in all_mappings.keys():
|
406 |
logger.warning("features not in all_mappings")
|
407 |
gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
|
|
|
|
|
408 |
feature_mapping = all_mappings["features"]
|
|
|
409 |
return label_mapping, feature_mapping
|
410 |
|
411 |
def show_hf_token_info(token):
|
|
|
138 |
ds_labels = list(shared_labels)
|
139 |
if len(ds_labels) > MAX_LABELS:
|
140 |
ds_labels = ds_labels[:MAX_LABELS]
|
141 |
+
gr.Warning(f"Too many labels to display for this spcae. We do not support more than {MAX_LABELS} in this space. You can use cli tool at https://github.com/Giskard-AI/cicd.")
|
142 |
|
143 |
# sort labels to make sure the order is consistent
|
144 |
# prediction gives the order based on probability
|
|
|
393 |
def construct_label_and_feature_mapping(all_mappings, ds_labels, ds_features):
|
394 |
label_mapping = {}
|
395 |
if len(all_mappings["labels"].keys()) != len(ds_labels):
|
396 |
+
logger.warn(f"""Label mapping corrupted: {CONFIRM_MAPPING_DETAILS_FAIL_RAW}.
|
397 |
+
\nall_mappings: {all_mappings}\nds_labels: {ds_labels}""")
|
398 |
|
399 |
if len(all_mappings["features"].keys()) != len(ds_features):
|
400 |
+
logger.warn(f"""Feature mapping corrupted: {CONFIRM_MAPPING_DETAILS_FAIL_RAW}.
|
401 |
+
\nall_mappings: {all_mappings}\nds_features: {ds_features}""")
|
402 |
|
403 |
for i, label in zip(range(len(ds_labels)), ds_labels):
|
404 |
# align the saved labels with dataset labels order
|
|
|
407 |
if "features" not in all_mappings.keys():
|
408 |
logger.warning("features not in all_mappings")
|
409 |
gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
|
410 |
+
|
411 |
+
special_classlabel = [feature for feature in ds_features if feature.startswith("label")]
|
412 |
feature_mapping = all_mappings["features"]
|
413 |
+
feature_mapping.update({"label": special_classlabel[0] if special_classlabel else "label"})
|
414 |
return label_mapping, feature_mapping
|
415 |
|
416 |
def show_hf_token_info(token):
|