Spaces:
Running
Running
import json | |
import logging | |
import datasets | |
import huggingface_hub | |
import pandas as pd | |
from transformers import pipeline | |
import requests | |
import os | |
from app_env import HF_WRITE_TOKEN | |
# Module logger, bound once to the dotted module name (``__name__``); a
# ``__file__``-based name would be a filesystem path instead.
logger = logging.getLogger(__name__)

# Endpoint queried by ``check_hf_token_validity`` with a bearer token.
AUTH_CHECK_URL = "https://huggingface.co/api/whoami-v2"
class HuggingFaceInferenceAPIResponse:
    """Carry a message to show in place of an inference result.

    ``get_example_prediction`` returns one of these (instead of a label/score
    dict) when the Inference API reports an error or an estimated wait time.
    """

    def __init__(self, message):
        # Stored verbatim; readers access it through the ``message`` attribute.
        self.message = message
def get_labels_and_features_from_dataset(ds):
    """Return ``(labels, features)`` describing a Hugging Face dataset split.

    ``labels`` are the class names of the first column whose name starts with
    ``"label"``. That column is either a ``datasets.ClassLabel`` or a container
    feature whose ``.feature`` has ``.names`` (e.g. a sequence of class
    labels). ``features`` are the remaining column names.

    If no ``label*`` column exists, every column name is returned for both
    values so the caller can post-process them. ``(None, None)`` is returned,
    with a warning, when the class names cannot be determined or reading the
    features fails.
    """
    try:
        dataset_features = ds.features
        label_keys = [k for k in dataset_features.keys() if k.startswith("label")]
        if not label_keys:
            # No labels found: return everything for post processing.
            all_columns = list(dataset_features.keys())
            return all_columns, list(all_columns)

        label_feature = dataset_features[label_keys[0]]
        if isinstance(label_feature, datasets.ClassLabel):
            labels = label_feature.names
        else:
            # Container features keep the class names on their inner ``.feature``.
            inner_feature = getattr(label_feature, "feature", None)
            labels = getattr(inner_feature, "names", None)

        if labels is None:
            logger.warning(
                "Get Labels/Features Failed for dataset: "
                "cannot read class names from column %r",
                label_keys[0],
            )
            return None, None

        features = [k for k in dataset_features.keys() if not k.startswith("label")]
        return labels, features
    except Exception as e:
        logger.warning("Get Labels/Features Failed for dataset: %s", e)
        return None, None
def check_model_task(model_id):
    """Return the Hub ``pipeline_tag`` of ``model_id``, or ``None``.

    ``None`` is returned both when the model has no pipeline tag and when the
    ``huggingface_hub.model_info`` lookup raises. Failures are logged rather
    than propagated.
    """
    try:
        return huggingface_hub.model_info(model_id).pipeline_tag
    except Exception as e:
        logger.warning("Failed to fetch model info for %s: %s", model_id, e)
        return None
def get_model_labels(model_id, example_input):
    """Query the Inference API with ``example_input`` and collect returned labels.

    The token is read from the environment variable named by ``HF_WRITE_TOKEN``
    (empty string if unset).

    Returns:
        Every ``"label"`` value found anywhere in the JSON response (see
        ``extract_from_response``), or ``None`` when the API answers with an
        error object.
    """
    hf_token = os.environ.get(HF_WRITE_TOKEN, default="")
    payload = {"inputs": example_input, "options": {"use_cache": True}}
    response = hf_inference_api(model_id, hf_token, payload)
    # Only a dict can carry an ``"error"`` key. A bare ``"error" in response``
    # would do a substring test on str replies and raise on numbers / null.
    if isinstance(response, dict) and "error" in response:
        return None
    return extract_from_response(response, "label")
def extract_from_response(data, key):
    """Collect every non-``None`` value stored under ``key`` in nested JSON data.

    Dicts and lists are walked depth-first. For a dict, its own ``key`` value
    comes first, followed by matches found inside each of its values in order.
    Any other type contributes nothing.
    """

    def _walk(node):
        if isinstance(node, dict):
            hit = node.get(key)
            if hit is not None:
                yield hit
            children = node.values()
        elif isinstance(node, list):
            children = node
        else:
            return
        for child in children:
            yield from _walk(child)

    return list(_walk(data))
def hf_inference_api(model_id, hf_token, payload, timeout=60):
    """POST ``payload`` to the Inference API for ``model_id`` and return the JSON reply.

    The endpoint base URL comes from the ``HF_INFERENCE_ENDPOINT`` environment
    variable (default ``https://api-inference.huggingface.co``).

    Args:
        model_id: Hub model id, appended to ``<endpoint>/models/``.
        hf_token: token sent as ``Authorization: Bearer <token>``.
        payload: JSON-serialisable request body.
        timeout: seconds passed to ``requests`` as the request timeout;
            ``None`` waits indefinitely.

    Returns:
        The decoded JSON body, also for non-200 replies (a warning is logged).
        ``{"error": <text>}`` when the body is not JSON or when the request
        itself fails (connection error, timeout, ...).
    """
    hf_inference_api_endpoint = os.environ.get(
        "HF_INFERENCE_ENDPOINT", default="https://api-inference.huggingface.co"
    )
    url = f"{hf_inference_api_endpoint}/models/{model_id}"
    headers = {"Authorization": f"Bearer {hf_token}"}
    try:
        response = requests.post(url, headers=headers, json=payload, timeout=timeout)
    except requests.RequestException as e:
        # Report transport failures through the same error-dict contract.
        logger.warning("Request to inference API failed: %s", e)
        return {"error": str(e)}
    if response.status_code != 200:
        logger.warning("Request to inference API returns %s", response)
    try:
        return response.json()
    except ValueError:
        # requests' JSON decode errors subclass ValueError. Return text rather
        # than bytes so callers can format the message directly.
        return {"error": response.text}
def preload_hf_inference_api(model_id):
    """Send a fixed sample text to the Inference API for ``model_id``.

    The reply is discarded. Presumably the call exists to get the model loaded
    on the API side before real requests arrive (TODO confirm).
    """
    token = os.environ.get(HF_WRITE_TOKEN, default="")
    warmup_request = {"inputs": "This is a test", "options": {"use_cache": True}}
    hf_inference_api(model_id, token, warmup_request)
def check_model_pipeline(model_id):
    """Build a ``transformers`` pipeline for ``model_id`` using its Hub task.

    Returns ``None`` if either the Hub lookup or the pipeline construction
    raises.
    """
    try:
        hub_task = huggingface_hub.model_info(model_id).pipeline_tag
        return pipeline(task=hub_task, model=model_id)
    except Exception:
        return None
def text_classificaiton_match_label_case_unsensative(id2label_mapping, label):
    """Find a key of ``id2label_mapping`` equal to ``label`` ignoring case.

    Returns ``(matching_key, label)`` for the first key, in iteration order,
    whose ``.upper()`` equals ``label.upper()``. Returns ``(None, label)`` when
    no key matches.
    """
    found = next(
        (
            candidate
            for candidate in id2label_mapping.keys()
            if candidate.upper() == label.upper()
        ),
        None,
    )
    return found, label
def text_classification_map_model_and_dataset_labels(id2label, dataset_features):
    """Match the model's output labels against a dataset ``ClassLabel`` column.

    Args:
        id2label: model config mapping of class id -> model label name.
        dataset_features: dataset ``features`` mapping column name -> feature.

    Returns:
        ``(id2label_mapping, dataset_labels)``.

        ``id2label_mapping`` maps each model label to the dataset label it
        matched: an exact match first, otherwise a case-insensitive one. It
        maps to ``None`` when nothing matched.

        ``dataset_labels`` are the names of the last ``ClassLabel`` column with
        as many classes as the model, or ``None`` if there is no such column.
    """
    id2label_mapping = {id2label[k]: None for k in id2label.keys()}
    dataset_labels = None
    for feature in dataset_features.values():
        if not isinstance(feature, datasets.ClassLabel):
            continue
        if len(feature.names) != len(id2label_mapping):
            continue
        # NOTE(review): every qualifying column is processed. With several
        # same-sized ClassLabel columns, the mapping can mix matches from
        # different columns while dataset_labels comes from the last one.
        dataset_labels = feature.names
        for label in feature.names:
            if label in id2label_mapping:
                model_label = label
            else:
                model_label, label = text_classificaiton_match_label_case_unsensative(
                    id2label_mapping, label
                )
            if model_label is not None:
                id2label_mapping[model_label] = label
            else:
                logger.warning("Label %s is not found in model labels", label)
    return id2label_mapping, dataset_labels


# check_column_mapping_keys_validity params:
#   column_mapping: JSON string, e.g.
#     '{"text": "sentences", "data": [["label0", "LABEL_0"], ["label1", "LABEL_1"]]}'
#     where each "data" entry is a [user_label, model_label] pair.
#   ppl: pipeline exposing ppl.model.config.id2label
def check_column_mapping_keys_validity(column_mapping, ppl):
    """Check that a user label mapping covers exactly the model's labels.

    Args:
        column_mapping: JSON text (str/bytes) or an already-decoded dict, e.g.
            ``'{"text": "sentences", "data": [["label0", "LABEL_0"], ...]}'``.
            Each ``"data"`` entry is a ``[user_label, model_label]`` pair.
        ppl: pipeline whose ``ppl.model.config.id2label`` values are the model
            labels.

    Returns:
        ``True`` when there is no ``"data"`` key. Otherwise, whether the set of
        user labels, the set of model labels in the pairs, and the set of
        ``id2label`` values are all equal.
    """
    if isinstance(column_mapping, (str, bytes, bytearray)):
        column_mapping = json.loads(column_mapping)
    if "data" not in column_mapping:
        return True
    pairs = column_mapping["data"]
    user_labels = {pair[0] for pair in pairs}
    model_labels = {pair[1] for pair in pairs}
    original_labels = set(ppl.model.config.id2label.values())
    return user_labels == model_labels == original_labels


# infer_text_input_column params:
#   column_mapping: dict
#   dataset_features: dict, e.g.
#     {'text': Value(dtype='string', id=None),
#      'label': ClassLabel(names=['negative', 'neutral', 'positive'], id=None)}
def infer_text_input_column(column_mapping, dataset_features):
    """Make sure ``column_mapping["text"]`` names a dataset column.

    Args:
        column_mapping: dict, mutated in place; may already hold ``"text"``.
        dataset_features: mapping of column name -> feature, e.g.
            ``{'text': Value(dtype='string'), 'label': ClassLabel(names=[...])}``.

    Returns:
        ``(column_mapping, feature_map_df)``.

        When ``column_mapping["text"]`` already names an existing column,
        nothing is inferred and ``feature_map_df`` is ``None``.

        Otherwise the first feature whose ``dtype`` is ``"string"`` is used, and
        ``feature_map_df`` is a one-row DataFrame showing that choice.

        If there is no string feature, ``column_mapping`` is left unchanged and
        ``feature_map_df`` is ``None``.
    """
    needs_inference = True
    feature_map_df = None
    if "text" in column_mapping:
        dataset_text_column = column_mapping["text"]
        if dataset_text_column in dataset_features:
            needs_inference = False
        else:
            logger.warning("Provided %s is not in Dataset columns", dataset_text_column)
    # NOTE(review): text_classification_fix_column_mapping treats a None
    # feature_map_df as "no usable features". That also happens for a valid
    # user-provided text column — confirm that is intended.
    if needs_inference:
        # Features without a ``dtype`` attribute are simply not candidates.
        candidates = [
            f for f in dataset_features
            if getattr(dataset_features[f], "dtype", None) == "string"
        ]
        # Only index candidates[0] when one exists (no string column -> no map).
        if candidates:
            logger.debug("Candidates are %s", candidates)
            column_mapping["text"] = candidates[0]
            feature_map_df = pd.DataFrame(
                {"Dataset Features": [candidates[0]], "Model Input Features": ["text"]}
            )
    return column_mapping, feature_map_df


# infer_output_label_column params:
#   column_mapping: dict
#   id2label_mapping: dict of model label -> dataset label, e.g.
#     {'negative': 'negative', 'neutral': 'neutral', 'positive': 'positive'}
def infer_output_label_column(
    column_mapping, id2label_mapping, id2label, dataset_labels
):
    """Build ``column_mapping["label"]`` and a preview table of label matches.

    Args:
        column_mapping: dict, mutated in place. A user-supplied ``"data"`` list
            of ``[user_label, model_label]`` pairs is copied into
            ``id2label_mapping``. Otherwise ``"label"`` is (re)computed.
        id2label_mapping: model label -> dataset label (``None`` if unmatched),
            e.g. ``{'negative': 'negative', 'neutral': 'neutral',
            'positive': 'positive'}``. Mutated in place.
        id2label: model config mapping of class id -> model label.
        dataset_labels: dataset class names, or ``None``.

    Returns:
        ``(column_mapping, id2label_df)``. ``id2label_df`` is ``None`` in two
        cases:

        - some model label is unmatched and no ``"data"`` key was given
          (``"label"`` is then set to ``{id: None}``);
        - ``dataset_labels`` is ``None``.

        Dataset labels that were matched case-insensitively, and so are not
        keys of ``id2label_mapping``, are resolved to their model label instead
        of raising ``KeyError``.
    """
    if "data" in column_mapping:
        if isinstance(column_mapping["data"], list):
            # Use the column mapping passed by user
            for user_label, model_label in column_mapping["data"]:
                id2label_mapping[model_label] = user_label
    elif None in id2label_mapping.values():
        column_mapping["label"] = {i: None for i in id2label.keys()}
        return column_mapping, None

    if dataset_labels is None:
        # Nothing to line the model labels up with; iterating None would raise.
        return column_mapping, None

    # Reverse view (dataset label -> model label) for labels that were matched
    # case-insensitively and are therefore not keys of id2label_mapping.
    dataset_to_model = {
        mapped: model_label
        for model_label, mapped in id2label_mapping.items()
        if mapped is not None
    }

    def _lookup(label):
        # Exact key first, so identical model/dataset labels behave as before.
        if label in id2label_mapping:
            return id2label_mapping[label]
        return dataset_to_model.get(label)

    if "data" not in column_mapping:
        # Column mapping should contain original model labels
        column_mapping["label"] = {
            str(i): _lookup(label) for i, label in zip(id2label.keys(), dataset_labels)
        }

    id2label_df = pd.DataFrame(
        {
            "Dataset Labels": dataset_labels,
            "Model Prediction Labels": [_lookup(label) for label in dataset_labels],
        }
    )
    return column_mapping, id2label_df
def check_dataset_features_validity(d_id, config, split):
    """Load a dataset split and return ``(dataframe, features)``.

    Loading itself is not guarded: errors from ``datasets.load_dataset``
    propagate to the caller. When the loaded object has no ``features``
    attribute, ``(None, None)`` is returned without converting it.
    """
    # NOTE(review): trust_remote_code=True lets the dataset repository run its
    # own loading code. If d_id comes from user input, confirm that is acceptable.
    split_ds = datasets.load_dataset(d_id, config, split=split, trust_remote_code=True)
    try:
        features = split_ds.features
    except AttributeError:
        return None, None
    else:
        return split_ds.to_pandas(), features
def select_the_first_string_column(ds):
    """Return the name of the first feature whose value in row 0 is a ``str``.

    Features are checked in ``ds.features`` order. ``None`` is returned when no
    feature holds a string or there are no features; in the latter case row 0
    is never read. Row 0 is fetched once instead of once per feature.
    """
    feature_names = list(ds.features.keys())
    if not feature_names:
        return None
    first_row = ds[0]
    return next(
        (name for name in feature_names if isinstance(first_row[name], str)), None
    )
def get_example_prediction(model_id, dataset_id, dataset_config, dataset_split, hf_token):
    """Run the Inference API on the first row of a dataset split.

    The input is the row's ``"text"`` column, or the column chosen by
    ``select_the_first_string_column`` when there is no ``"text"`` feature.

    Returns:
        ``(prediction_input, prediction_result)``, where ``prediction_result`` is:

        - a ``{label: score}`` dict on success;
        - a ``HuggingFaceInferenceAPIResponse`` with a user-facing message when
          the API returns an error (including an ``estimated_time`` hint);
        - ``None`` when anything else fails. The exception is logged with its
          traceback.
    """
    prediction_input = None
    try:
        # NOTE(review): trust_remote_code=True runs the dataset repo's loading
        # code. dataset_id may be user-supplied — confirm this is acceptable.
        ds = datasets.load_dataset(
            dataset_id, dataset_config, split=dataset_split, trust_remote_code=True
        )
        if "text" in ds.features.keys():
            prediction_input = ds[0]["text"]
        else:
            # Dataset does not have text column
            prediction_input = ds[0][select_the_first_string_column(ds)]

        payload = {"inputs": prediction_input, "options": {"use_cache": True}}
        results = hf_inference_api(model_id, hf_token, payload)

        if isinstance(results, dict) and "error" in results:
            if "estimated_time" in results:
                return prediction_input, HuggingFaceInferenceAPIResponse(
                    f"Estimated time: {int(results['estimated_time'])}s. Please try again later.")
            return prediction_input, HuggingFaceInferenceAPIResponse(
                f"Inference Error: {results['error']}.")

        # Unwrap nested lists (e.g. [[{...}, ...]]) down to the list of dicts.
        while isinstance(results, list):
            if isinstance(results[0], dict):
                break
            results = results[0]
        prediction_result = {
            f'{result["label"]}': result["score"] for result in results
        }
    except Exception as e:
        # Inference API prediction failed: keep the traceback for debugging.
        logger.exception("Get example prediction failed %s", e)
        return prediction_input, None
    return prediction_input, prediction_result
def get_sample_prediction(ppl, df, column_mapping, id2label_mapping):
    """Run the local pipeline on the first row of ``df``.

    Args:
        ppl: pipeline, called as ``ppl({"text": ...}, top_k=None)`` and
            expected to return an iterable of ``{"label", "score"}`` dicts.
        df: dataset as a DataFrame.
        column_mapping: must map ``"text"`` to the input column of ``df``.
        id2label_mapping: model label -> mapped (dataset) label.

    Returns:
        ``(prediction_input, prediction_result)``. ``prediction_result`` maps
        ``"<label>(original) - <mapped>(mapped)"`` to the score, or is ``None``
        if reading the input or calling the pipeline fails. Labels missing
        from ``id2label_mapping`` are shown as ``None(mapped)`` rather than
        raising ``KeyError``.
    """
    prediction_input = None
    try:
        # First row by position; Dataset.to_pandas() yields a RangeIndex, so
        # this is the row labelled 0 there.
        prediction_input = df[column_mapping["text"]].iloc[0]
        results = ppl({"text": prediction_input}, top_k=None)
        # Display results in original label and mapped label
        prediction_result = {
            f'{result["label"]}(original) - '
            f'{id2label_mapping.get(result["label"])}(mapped)': result["score"]
            for result in results
        }
    except Exception:
        # Pipeline prediction failed, need to provide labels
        return prediction_input, None
    return prediction_input, prediction_result
def text_classification_fix_column_mapping(column_mapping, ppl, d_id, config, split):
    """Infer the text column and the label mapping between a dataset split and ``ppl``.

    Returns a 5-tuple ``(column_mapping, prediction_input, prediction_result,
    id2label_df, feature_map_df)``. Items for stages that could not be
    completed are ``None``:

    - all five are ``None`` when the dataset exposes no ``features`` or
      ``infer_text_input_column`` returns no feature map;
    - ``(column_mapping, None, None, None, feature_map_df)`` when the output
      labels cannot be inferred;
    - ``prediction_result`` is ``None`` when the sample prediction fails.
    """
    # Load dataset as a DataFrame and get its features.
    df, dataset_features = check_dataset_features_validity(d_id, config, split)
    if dataset_features is None:
        # Dataset has no features; the inference steps below cannot run.
        return None, None, None, None, None

    column_mapping, feature_map_df = infer_text_input_column(
        column_mapping, dataset_features
    )
    if feature_map_df is None:
        return None, None, None, None, None

    # Retrieve all model labels and match them to dataset labels.
    id2label = ppl.model.config.id2label
    id2label_mapping, dataset_labels = text_classification_map_model_and_dataset_labels(
        id2label, dataset_features
    )
    column_mapping, id2label_df = infer_output_label_column(
        column_mapping, id2label_mapping, id2label, dataset_labels
    )
    if id2label_df is None:
        # Unable to infer the output label column.
        return column_mapping, None, None, None, feature_map_df

    # prediction_result is None when the sample prediction fails.
    prediction_input, prediction_result = get_sample_prediction(
        ppl, df, column_mapping, id2label_mapping
    )
    return (
        column_mapping,
        prediction_input,
        prediction_result,
        id2label_df,
        feature_map_df,
    )
def strip_model_id_from_url(model_id):
    """Turn a ``https://huggingface.co/...`` model URL into a Hub model id.

    Only the first two path segments after the host are kept, so:

    - trailing slashes are dropped;
    - sub-paths such as ``/tree/main`` are dropped;
    - single-segment ids (``.../gpt2``) stay ``gpt2``.

    Anything that does not start with the Hub URL prefix is returned unchanged.
    """
    prefix = "https://huggingface.co/"
    if not model_id.startswith(prefix):
        return model_id
    path = model_id[len(prefix):].strip("/")
    return "/".join(path.split("/")[:2])
def check_hf_token_validity(hf_token, timeout=10):
    """Return whether ``AUTH_CHECK_URL`` accepts ``hf_token`` (HTTP 200).

    Non-string values and empty or whitespace-only strings are rejected
    without a network call.

    A request failure (connection error, timeout, ...) is logged and reported
    as ``False`` instead of propagating. ``timeout`` is the ``requests``
    timeout in seconds.
    """
    if not isinstance(hf_token, str) or not hf_token.strip():
        return False
    # Use the Hugging Face API to check the token.
    headers = {"Authorization": f"Bearer {hf_token}"}
    try:
        response = requests.get(AUTH_CHECK_URL, headers=headers, timeout=timeout)
    except requests.RequestException as e:
        logger.warning("Token validity check failed: %s", e)
        return False
    return response.status_code == 200