Spaces:
Running
Running
import datasets | |
import logging | |
import json | |
import pandas as pd | |
def text_classificaiton_match_label_case_unsensative(id2label_mapping, label):
    """Look up a model label that equals ``label`` when case is ignored.

    Args:
        id2label_mapping: mapping whose keys are the model's label names.
        label: dataset label name to look up.

    Returns:
        A pair ``(model_label, label)``. ``model_label`` is the first key whose
        ``.upper()`` form equals ``label.upper()``, or ``None`` when no key
        matches. ``label`` is always handed back unchanged.
    """
    # Comparison order (key first, then label) mirrors a plain loop, and an
    # empty mapping never touches ``label`` at all.
    match = next(
        (
            candidate
            for candidate in id2label_mapping.keys()
            if candidate.upper() == label.upper()
        ),
        None,
    )
    return match, label
def text_classification_map_model_and_dataset_labels(id2label, dataset_features):
    """Pair each model label with a dataset ``ClassLabel`` name.

    Args:
        id2label: the model's ``id -> label`` mapping.
        dataset_features: mapping of column name -> ``datasets`` feature object.

    Returns:
        ``(id2label_mapping, dataset_labels)``.

        ``id2label_mapping`` maps every model label to the dataset label matched
        to it, or ``None`` when nothing matched. An exact match is tried first,
        then a case-insensitive one.

        ``dataset_labels`` is the ``names`` list of the last ``ClassLabel``
        feature that has as many classes as the model, or ``None`` if no such
        feature exists.

    Note:
        Every qualifying ``ClassLabel`` feature is scanned in order. Matches
        from several such columns therefore accumulate in the same mapping.
    """
    id2label_mapping = {model_label: None for model_label in id2label.values()}
    dataset_labels = None

    # Only ClassLabel columns whose class count equals the model's are usable.
    same_sized_class_labels = (
        feature
        for feature in dataset_features.values()
        if isinstance(feature, datasets.ClassLabel)
        and len(feature.names) == len(id2label_mapping)
    )
    for feature in same_sized_class_labels:
        dataset_labels = feature.names
        for dataset_label in feature.names:
            if dataset_label in id2label_mapping:
                matched = dataset_label
            else:
                # Fall back to a case-insensitive comparison.
                matched, dataset_label = text_classificaiton_match_label_case_unsensative(
                    id2label_mapping, dataset_label
                )
            if matched is not None:
                id2label_mapping[matched] = dataset_label
    return id2label_mapping, dataset_labels
def check_column_mapping_keys_validity(column_mapping, ppl):
    """Check that a user-supplied label mapping covers exactly the model's labels.

    Args:
        column_mapping: the column mapping. It may be a JSON string/bytes or an
            already-decoded dict. Its optional ``"data"`` entry is a list of
            ``[user_label, model_label]`` pairs.
        ppl: pipeline exposing ``ppl.model.config.id2label``.

    Returns:
        ``True`` when there is no ``"data"`` entry. Otherwise, whether these
        three sets are all equal:

        - the first elements of the pairs,
        - the second elements of the pairs,
        - the model's own label names.

    Raises:
        json.JSONDecodeError: if a textual ``column_mapping`` is not valid JSON.
    """
    # Accept both the serialized form and a dict. The other helpers in this
    # module work on dict mappings; json.loads would reject a dict with TypeError.
    if isinstance(column_mapping, (str, bytes, bytearray)):
        column_mapping = json.loads(column_mapping)
    if "data" not in column_mapping:
        return True
    pairs = column_mapping["data"]
    user_labels = {pair[0] for pair in pairs}
    model_labels = {pair[1] for pair in pairs}
    original_labels = set(ppl.model.config.id2label.values())
    return user_labels == model_labels == original_labels
def text_classification_fix_column_mapping(column_mapping, ppl, d_id, config, split):
    """Validate and complete a text-classification column mapping for a dataset split.

    The function:

    1. Loads ``split`` of dataset ``d_id``/``config``.
    2. Makes sure ``column_mapping["text"]`` names an existing column, inferring
       the first string column when needed.
    3. Runs ``ppl`` on the first example.
    4. Maps the model labels onto the dataset's ``ClassLabel`` names, or onto
       the user-supplied pairs in ``column_mapping["data"]``.

    Args:
        column_mapping: dict that may contain ``"text"`` (input column name) and
            ``"data"`` (list of ``[dataset_label, model_label]`` pairs). It is
            updated in place; ``None`` is treated as an empty mapping.
        ppl: text-classification pipeline. It must expose
            ``ppl.model.config.id2label``, and ``ppl({"text": ...}, top_k=None)``
            must return dicts with ``"label"`` and ``"score"`` keys.
        d_id: dataset id passed to ``datasets.load_dataset``.
        config: dataset config name passed to ``datasets.load_dataset``.
        split: split name selected from the loaded dataset.

    Returns:
        Always a 4-tuple
        ``(column_mapping, prediction_input, prediction_result, id2label_df)``.
        Entries that could not be computed are ``None``:

        - all four are ``None`` when the split exposes no ``features``;
        - the last three are ``None`` when no text column can be found or the
          pipeline call fails;
        - ``id2label_df`` is ``None`` when some model labels are unmatched or no
          usable ``ClassLabel`` feature exists.
    """
    if column_mapping is None:
        column_mapping = {}

    # We assume dataset is ok here
    ds = datasets.load_dataset(d_id, config)[split]
    try:
        dataset_features = ds.features
    except AttributeError:
        # Dataset does not have features, need to provide everything
        return None, None, None, None

    # Check whether we need to infer the text input column
    infer_text_input_column = True
    if "text" in column_mapping:
        dataset_text_column = column_mapping["text"]
        if dataset_text_column in dataset_features:
            infer_text_input_column = False
        else:
            logging.warning("Provided %s is not in Dataset columns", dataset_text_column)

    if infer_text_input_column:
        # Use the first plain string column. Nested features (e.g. a dict of
        # sub-features) may have no ``dtype`` attribute, so read it defensively.
        candidates = [
            name
            for name, feature in dataset_features.items()
            if getattr(feature, "dtype", None) == "string"
        ]
        if not candidates:
            # Not found a text feature
            return column_mapping, None, None, None
        logging.debug("Candidates are %s", candidates)
        column_mapping["text"] = candidates[0]

    id2label = ppl.model.config.id2label
    label2id = {v: k for k, v in id2label.items()}

    prediction_input = None
    try:
        # Use the first example to test prediction. Indexing the split directly
        # avoids converting the entire dataset to a DataFrame for one value.
        prediction_input = ds[0][column_mapping["text"]]
        results = ppl({"text": prediction_input}, top_k=None)
        prediction_result = {
            f'{result["label"]}({label2id[result["label"]]})': result["score"]
            for result in results
        }
    except Exception:
        # Best effort: the caller must then provide labels.
        # Keep the reason in the log instead of swallowing it.
        logging.warning("Pipeline prediction failed on the first example", exc_info=True)
        return column_mapping, None, None, None

    # Infer labels: maps each model label to its dataset label (or None).
    id2label_mapping, dataset_labels = text_classification_map_model_and_dataset_labels(
        id2label, dataset_features
    )
    if "data" in column_mapping:
        if isinstance(column_mapping["data"], list):
            # Use the column mapping passed by user
            for user_label, model_label in column_mapping["data"]:
                id2label_mapping[model_label] = user_label
    elif None in id2label_mapping.values():
        # Some model labels have no dataset counterpart; ask for them explicitly.
        column_mapping["label"] = {i: None for i in id2label}
        return column_mapping, prediction_input, prediction_result, None

    prediction_result = {
        f'[{label2id[result["label"]]}]{result["label"]}(original) - {id2label_mapping[result["label"]]}(mapped)': result["score"]
        for result in results
    }

    if dataset_labels is None:
        # No ClassLabel feature with the model's class count: nothing to tabulate.
        return column_mapping, prediction_input, prediction_result, None

    # id2label_mapping is keyed by *model* label, but the table and the label
    # mapping below are indexed by *dataset* label, so invert it. Indexing it
    # directly by a dataset label only works when both spellings coincide; it
    # raises KeyError after a case-insensitive match.
    dataset_to_model = {
        dataset_label: model_label
        for model_label, dataset_label in id2label_mapping.items()
        if dataset_label is not None
    }
    id2label_df = pd.DataFrame({
        "Dataset Labels": dataset_labels,
        "Model Prediction Labels": [dataset_to_model.get(label) for label in dataset_labels],
    })

    if "data" not in column_mapping:
        # Column mapping should contain original model labels. Model ids are
        # paired positionally with the dataset's class names.
        column_mapping["label"] = {
            str(i): dataset_to_model.get(label) for i, label in zip(id2label, dataset_labels)
        }

    return column_mapping, prediction_input, prediction_result, id2label_df