import json
import logging

import datasets
import pandas as pd


def text_classificaiton_match_label_case_unsensative(id2label_mapping, label):
    """Find a model label equal to ``label`` ignoring case.

    The (misspelled) function name is kept unchanged for backward compatibility.

    Args:
        id2label_mapping: dict keyed by model label names; only its keys are read.
        label: dataset label name to look up.

    Returns:
        ``(model_label, label)`` for the first key whose ``upper()`` equals
        ``label.upper()``, otherwise ``(None, label)``.
    """
    target = label.upper()
    for model_label in id2label_mapping:
        if model_label.upper() == target:
            return model_label, label
    return None, label


def text_classification_map_model_and_dataset_labels(id2label, dataset_features):
    """Match model label names against the names of a dataset ``ClassLabel`` feature.

    Args:
        id2label: the model's id -> label-name mapping (e.g. ``config.id2label``).
        dataset_features: mapping of column name -> ``datasets`` feature.

    Returns:
        ``(id2label_mapping, dataset_labels)``: ``id2label_mapping`` maps every model
        label name to its matched dataset label name (``None`` when unmatched);
        ``dataset_labels`` is the ``names`` list of the last ``ClassLabel`` feature
        that has as many classes as the model, or ``None`` if there is none.
    """
    id2label_mapping = {model_label: None for model_label in id2label.values()}
    dataset_labels = None
    for feature in dataset_features.values():
        # Only class-label columns with exactly as many classes as the model qualify.
        if not isinstance(feature, datasets.ClassLabel):
            continue
        if len(feature.names) != len(id2label_mapping):
            continue
        # NOTE(review): with several qualifying ClassLabel features, matches accumulate
        # across all of them while dataset_labels keeps only the last one's names.
        dataset_labels = feature.names
        for label in feature.names:
            if label in id2label_mapping:
                model_label = label
            else:
                # Fall back to a case-insensitive match
                model_label, label = text_classificaiton_match_label_case_unsensative(id2label_mapping, label)
            if model_label is not None:
                id2label_mapping[model_label] = label
    return id2label_mapping, dataset_labels


def check_column_mapping_keys_validity(column_mapping, ppl):
    """Check that a user-supplied label mapping covers exactly the model's labels.

    Args:
        column_mapping: JSON string; its optional ``"data"`` entry is a list of
            ``[user_label, model_label]`` pairs.
        ppl: pipeline whose ``model.config.id2label`` provides the model's labels.

    Returns:
        ``True`` when there is no ``"data"`` entry; otherwise whether the set of user
        labels, the set of model labels in the pairs and the model's own label set
        are all equal.
    """
    column_mapping = json.loads(column_mapping)
    if "data" not in column_mapping:
        return True
    user_labels = {pair[0] for pair in column_mapping["data"]}
    model_labels = {pair[1] for pair in column_mapping["data"]}
    original_labels = set(ppl.model.config.id2label.values())
    return user_labels == model_labels == original_labels


def _ensure_text_column(column_mapping, dataset_features):
    """Make ``column_mapping["text"]`` name a string column; return False if none exists."""
    if "text" in column_mapping:
        dataset_text_column = column_mapping["text"]
        if dataset_text_column in dataset_features:
            return True
        logging.warning(f"Provided {dataset_text_column} is not in Dataset columns")
    # Try to retrieve one; nested (dict) features have no ``dtype`` so they are skipped.
    candidates = [
        name for name, feature in dataset_features.items()
        if getattr(feature, "dtype", None) == "string"
    ]
    if not candidates:
        return False
    logging.debug(f"Candidates are {candidates}")
    column_mapping["text"] = candidates[0]
    return True


def text_classification_fix_column_mapping(column_mapping, ppl, d_id, config, split):
    """Complete a text-classification column mapping for one dataset split.

    Loads ``datasets.load_dataset(d_id, config)[split]`` and checks that
    ``column_mapping["text"]`` names a string column, inferring the first string
    column when needed. It then runs ``ppl`` on the first example and maps the
    model's labels to the dataset's ``ClassLabel`` names, with the user's
    ``column_mapping["data"]`` pairs taking precedence when they are given.

    Args:
        column_mapping: dict, updated in place and returned. May contain ``"text"``
            (input column) and ``"data"`` (list of ``[user_label, model_label]``).
        ppl: text-classification pipeline; ``ppl.model.config.id2label`` is read.
        d_id: dataset id forwarded to ``datasets.load_dataset``.
        config: dataset config name forwarded to ``datasets.load_dataset``.
        split: split to evaluate.

    Returns:
        Always a 4-tuple ``(column_mapping, prediction_input, prediction_result,
        id2label_df)``:

        - ``(None, None, None, None)`` if the split has no ``features``;
        - ``(column_mapping, None, None, None)`` if no text column exists or the
          pipeline call fails;
        - ``(column_mapping, prediction_input, prediction_result, None)`` if no
          ``"data"`` was given and some model labels could not be matched
          (``column_mapping["label"]`` is then ``{id: None}``);
        - otherwise all four values. ``id2label_df`` is ``None`` when the dataset
          has no ``ClassLabel`` feature matching the model's class count.
    """
    # We assume dataset is ok here
    ds = datasets.load_dataset(d_id, config)[split]

    try:
        dataset_features = ds.features
    except AttributeError:
        # Dataset does not have features, need to provide everything
        return None, None, None, None

    if not _ensure_text_column(column_mapping, dataset_features):
        # Not found a text feature
        return column_mapping, None, None, None

    id2label = ppl.model.config.id2label
    label2id = {v: k for k, v in id2label.items()}

    try:
        # Use the first example to test prediction (no need to load the whole split)
        prediction_input = ds[0][column_mapping["text"]]
        results = ppl({"text": prediction_input}, top_k=None)
        prediction_result = {
            f'{result["label"]}({label2id[result["label"]]})': result["score"]
            for result in results
        }
    except Exception as err:
        # Pipeline prediction failed, need to provide labels (best effort: log, don't raise)
        logging.warning("Pipeline prediction failed, labels need to be provided: %s", err)
        return column_mapping, None, None, None

    # Infer labels from the dataset's ClassLabel feature
    id2label_mapping, dataset_labels = text_classification_map_model_and_dataset_labels(id2label, dataset_features)

    if "data" in column_mapping:
        if isinstance(column_mapping["data"], list):
            # Use the column mapping passed by user (overrides inferred matches)
            for user_label, model_label in column_mapping["data"]:
                id2label_mapping[model_label] = user_label
    elif None in id2label_mapping.values():
        # NOTE(review): keys here are raw id2label keys while the success path uses str(i).
        column_mapping["label"] = {i: None for i in id2label}
        return column_mapping, prediction_input, prediction_result, None

    # Built only after user pairs are applied so the returned table reflects them.
    id2label_mapping_dataset_model = {v: k for k, v in id2label_mapping.items()}

    prediction_result = {
        f'[{label2id[result["label"]]}]{result["label"]}(original) - {id2label_mapping[result["label"]]}(mapped)': result["score"]
        for result in results
    }

    id2label_df = None
    if dataset_labels is not None:
        id2label_df = pd.DataFrame({
            "Dataset Labels": dataset_labels,
            # .get: with user pairs, some dataset labels may have no model label
            "Model Prediction Labels": [id2label_mapping_dataset_model.get(label) for label in dataset_labels],
        })
        if "data" not in column_mapping:
            # Column mapping should contain original model labels
            column_mapping["label"] = {
                str(i): id2label_mapping_dataset_model[label]
                for i, label in zip(id2label.keys(), dataset_labels)
            }

    return column_mapping, prediction_input, prediction_result, id2label_df