giskard-evaluator / text_classification.py
inoki-giskard's picture
GSK-2396-allow-edit-feature-mappings (#12)
3a0ee14
raw
history blame
5.72 kB
import datasets
import logging
import json
import pandas as pd
def text_classificaiton_match_label_case_unsensative(id2label_mapping, label):
    """Look up ``label`` among the keys of ``id2label_mapping``, ignoring case.

    Both sides are compared after ``str.upper()``.

    Args:
        id2label_mapping: mapping whose keys are the model label names.
        label: the label to search for.

    Returns:
        ``(model_label, label)`` where ``model_label`` is the first key that
        matches, or ``(None, label)`` when no key matches.
    """
    matched_key = next(
        (
            candidate
            for candidate in id2label_mapping.keys()
            if candidate.upper() == label.upper()
        ),
        None,
    )
    return matched_key, label
def text_classification_map_model_and_dataset_labels(id2label, dataset_features):
    """Pair the model's label names with the label names of a dataset feature.

    Every ``datasets.ClassLabel`` feature that has as many names as the model
    has labels is examined. Each of its names is matched against the model
    labels, first exactly and then ignoring case.

    Args:
        id2label: mapping from model label id to model label name.
        dataset_features: mapping from column name to feature object.

    Returns:
        ``(id2label_mapping, dataset_labels)``. ``id2label_mapping`` maps each
        model label name to the matching dataset label name (``None`` when
        nothing matched). ``dataset_labels`` is the ``names`` list of the last
        examined feature, or ``None`` when no feature qualified.
    """
    id2label_mapping = dict.fromkeys(id2label.values())
    dataset_labels = None

    for feature in dataset_features.values():
        # Only class-label columns with the same number of classes as the model.
        qualifies = isinstance(feature, datasets.ClassLabel) and len(
            feature.names
        ) == len(id2label_mapping.keys())
        if not qualifies:
            continue

        dataset_labels = feature.names
        for dataset_label in feature.names:
            if dataset_label in id2label_mapping.keys():
                matched_model_label = dataset_label
            else:
                # No exact match: fall back to a case-insensitive lookup.
                matched_model_label, dataset_label = (
                    text_classificaiton_match_label_case_unsensative(
                        id2label_mapping, dataset_label
                    )
                )

            if matched_model_label is None:
                print(f"Label {dataset_label} is not found in model labels")
            else:
                id2label_mapping[matched_model_label] = dataset_label

    return id2label_mapping, dataset_labels
def check_column_mapping_keys_validity(column_mapping, ppl):
    """Check the label pairs of a serialized column mapping against the model.

    Args:
        column_mapping: JSON text; it is decoded with ``json.loads``. Only the
            ``"data"`` entry is inspected here, and it is read as a sequence
            of ``(user_label, model_label)`` pairs.
            NOTE(review): an earlier note on this function described the
            mapping as ``{"text": "sentences", "label": {"label0": "LABEL_0",
            "label1": "LABEL_1"}}``; those keys are not read by this function.
        ppl: pipeline; ``ppl.model.config.id2label`` supplies the model's
            original label names.

    Returns:
        ``True`` when the decoded mapping has no ``"data"`` key. Otherwise
        ``True`` only if the set of first elements, the set of second
        elements and the set of the model's label names are all equal.
    """
    parsed_mapping = json.loads(column_mapping)

    if "data" in parsed_mapping.keys():
        label_pairs = parsed_mapping["data"]
        user_labels = {pair[0] for pair in label_pairs}
        model_labels = {pair[1] for pair in label_pairs}
        original_labels = set(ppl.model.config.id2label.values())
        return user_labels == model_labels == original_labels

    # Nothing to validate when no label pairs were supplied.
    return True
def infer_text_input_column(column_mapping, dataset_features):
    """Make sure ``column_mapping["text"]`` names a column of the dataset.

    When ``column_mapping`` already has a ``"text"`` entry that is a key of
    ``dataset_features``, nothing is changed. Otherwise the first feature
    whose ``dtype`` is ``"string"`` is picked as the text column.

    Args:
        column_mapping: dict that may contain a ``"text"`` entry. It is
            updated in place when a column is inferred.
        dataset_features: mapping from column name to feature object.

    Returns:
        ``(column_mapping, feature_map_df)``. ``feature_map_df`` is a one-row
        ``pandas.DataFrame`` with columns ``"Dataset Features"`` and
        ``"Model Input Features"`` when a column was inferred. It is ``None``
        when the provided column was valid or no string column exists.
    """
    feature_map_df = None

    if "text" in column_mapping:
        dataset_text_column = column_mapping["text"]
        if dataset_text_column in dataset_features:
            # The provided column exists: nothing to infer.
            return column_mapping, feature_map_df
        logging.warning(f"Provided {dataset_text_column} is not in Dataset columns")

    # Features without a ``dtype`` attribute (e.g. nested dict features) are
    # skipped instead of raising AttributeError.
    candidates = [
        name
        for name, feature in dataset_features.items()
        if getattr(feature, "dtype", None) == "string"
    ]
    if not candidates:
        # Previously ``candidates[0]`` was read before this check and raised
        # IndexError for datasets without any string column.
        logging.warning("No string column found in Dataset features to use as text input")
        return column_mapping, feature_map_df

    logging.debug(f"Candidates are {candidates}")
    column_mapping["text"] = candidates[0]
    feature_map_df = pd.DataFrame({
        "Dataset Features": [candidates[0]],
        "Model Input Features": ["text"],
    })
    return column_mapping, feature_map_df
def text_classification_fix_column_mapping(column_mapping, ppl, d_id, config, split):
    """Complete ``column_mapping`` for a model/dataset pair and try one prediction.

    Args:
        column_mapping: dict that may contain ``"text"`` and ``"data"``
            entries. It is updated in place.
        ppl: pipeline; called on one sample, and its
            ``model.config.id2label`` provides the model labels.
        d_id: dataset identifier passed to ``datasets.load_dataset``.
        config: dataset configuration passed to ``datasets.load_dataset``.
        split: key of the split taken from the loaded dataset.

    Returns:
        A 5-tuple ``(column_mapping, prediction_input, prediction_result,
        id2label_df, feature_map_df)``:

        * all ``None`` when the split has no ``features`` attribute;
        * ``(column_mapping, None, None, None, feature_map_df)`` when no
          ``"data"`` entry was supplied and some model label could not be
          matched to a dataset label (``column_mapping["label"]`` is then set
          to ``None`` for every label id);
        * ``(column_mapping, prediction_input, None, id2label_df,
          feature_map_df)`` when the sample prediction raised;
        * otherwise the full tuple, with ``prediction_result`` mapping a
          formatted label description to its score.
    """
    # The dataset is assumed to be loadable at this point.
    dataset_split = datasets.load_dataset(d_id, config)[split]
    try:
        dataset_features = dataset_split.features
    except AttributeError:
        # No features on the dataset: the caller has to provide everything.
        return None, None, None, None, None

    column_mapping, feature_map_df = infer_text_input_column(column_mapping, dataset_features)

    # Materialise the split as a DataFrame to pick a sample row later.
    dataset_frame = dataset_split.to_pandas()

    # Model labels and their reverse lookup (label name -> label id).
    id2label = ppl.model.config.id2label
    label2id = {name: idx for idx, name in id2label.items()}

    # Inferred pairing of model labels with dataset labels.
    id2label_mapping, dataset_labels = text_classification_map_model_and_dataset_labels(
        id2label, dataset_features
    )
    # Dataset label -> model label; built before any user-provided pairs
    # are applied below.
    id2label_mapping_dataset_model = {
        dataset_label: model_label
        for model_label, dataset_label in id2label_mapping.items()
    }

    has_user_pairs = "data" in column_mapping.keys()
    if has_user_pairs:
        if isinstance(column_mapping["data"], list):
            # Use the column mapping passed by user
            for user_label, model_label in column_mapping["data"]:
                id2label_mapping[model_label] = user_label
    elif None in id2label_mapping.values():
        # Inference left gaps and the user gave no pairs: ask for labels.
        column_mapping["label"] = dict.fromkeys(id2label.keys())
        return column_mapping, None, None, None, feature_map_df

    id2label_df = pd.DataFrame({
        "Dataset Labels": dataset_labels,
        "Model Prediction Labels": [
            id2label_mapping_dataset_model[dataset_label]
            for dataset_label in dataset_labels
        ],
    })

    # Get a sample prediction from the model on the dataset.
    prediction_input = None
    try:
        # Use the first item to test prediction
        prediction_input = dataset_frame.head(1).at[0, column_mapping["text"]]
        results = ppl({"text": prediction_input}, top_k=None)
        # Formatting stays inside the try so that a predicted label missing
        # from label2id is handled like any other prediction failure.
        prediction_result = {
            f'{result["label"]}({label2id[result["label"]]})': result["score"]
            for result in results
        }
    except Exception as e:
        # Pipeline prediction failed, need to provide labels
        print(e, '>>>> error')
        return column_mapping, prediction_input, None, id2label_df, feature_map_df

    # Final presentation: original model label plus the label it maps to.
    prediction_result = {}
    for result in results:
        model_label = result["label"]
        description = f'[{label2id[model_label]}]{model_label}(original) - {id2label_mapping[model_label]}(mapped)'
        prediction_result[description] = result["score"]

    if not has_user_pairs:
        # Column mapping should contain original model labels
        column_mapping["label"] = {
            str(label_id): id2label_mapping_dataset_model[dataset_label]
            for label_id, dataset_label in zip(id2label.keys(), dataset_labels)
        }

    return column_mapping, prediction_input, prediction_result, id2label_df, feature_map_df