Spaces:
Sleeping
Sleeping
import collections | |
import json | |
import logging | |
import os | |
import threading | |
import datasets | |
import gradio as gr | |
from transformers.pipelines import TextClassificationPipeline | |
from io_utils import (get_yaml_path, read_column_mapping, save_job_to_pipe, | |
write_column_mapping, write_log_to_user_file) | |
from text_classification import (check_model, get_example_prediction, | |
get_labels_and_features_from_dataset) | |
from wordings import CONFIRM_MAPPING_DETAILS_FAIL_RAW | |
# Number of label / feature mapping dropdowns the UI always renders; unused
# slots are emitted as hidden dropdowns so the output count stays fixed.
MAX_FEATURES = 20
MAX_LABELS = 20

# Names of the environment variables read when a scan job is submitted.
HF_WRITE_TOKEN = "HF_WRITE_TOKEN"
HF_REPO_ID = "HF_REPO_ID"
HF_SPACE_ID = "SPACE_ID"  # note: the env var itself is named SPACE_ID
def check_dataset_and_get_config(dataset_id, uid):
    """Reset the user's column mapping and list the configs of ``dataset_id``.

    Args:
        dataset_id: Hub id of the dataset to inspect.
        uid: Identifier of the user session whose stored column mapping is
            cleared (``write_column_mapping(None, uid)``) before the lookup.

    Returns:
        A visible ``gr.Dropdown`` listing the dataset configs with the first
        one preselected, or ``None`` when anything fails (e.g. the dataset
        does not exist).
    """
    try:
        write_column_mapping(None, uid)  # reset column mapping
        configs = datasets.get_dataset_config_names(dataset_id)
        return gr.Dropdown(configs, value=configs[0], visible=True)
    except Exception as e:
        # Best effort: the dataset may not exist. Keep returning None, but
        # record why instead of silently discarding the error.
        logging.warning("Failed to get configs for dataset %s: %s", dataset_id, e)
        return None
def check_dataset_and_get_split(dataset_id, dataset_config):
    """List the splits of one dataset config as a dropdown.

    Uses ``datasets.get_dataset_split_names`` to obtain the split names
    rather than materialising the whole dataset with ``datasets.load_dataset``
    just to read its keys.

    Args:
        dataset_id: Hub id of the dataset.
        dataset_config: Config name of the dataset.

    Returns:
        A visible ``gr.Dropdown`` of split names with the first one
        preselected, or ``None`` when anything fails (e.g. the dataset or
        config does not exist).
    """
    try:
        splits = list(datasets.get_dataset_split_names(dataset_id, dataset_config))
        return gr.Dropdown(splits, value=splits[0], visible=True)
    except Exception as e:
        # Best effort: the dataset may not exist. Keep returning None, but
        # record why instead of silently discarding the error.
        logging.warning(
            "Failed to load dataset %s with config %s: %s",
            dataset_id,
            dataset_config,
            e,
        )
        return None
def write_column_mapping_to_config(
    dataset_id, dataset_config, dataset_split, uid, *labels
):
    """Persist the label/feature mapping chosen in the UI dropdowns.

    Args:
        dataset_id: Hub id of the dataset.
        dataset_config: Dataset config name.
        dataset_split: Dataset split name; together with the two above it is
            passed to ``get_labels_and_features_from_dataset``.
        uid: User session whose stored column mapping is read and rewritten.
        *labels: Dropdown values. Entries ``[0, MAX_LABELS)`` are the values
            chosen for each dataset label (matched to dataset labels by
            position). Entries ``[MAX_LABELS, MAX_LABELS + MAX_FEATURES)``
            are chosen feature columns. Falsy entries are ignored.
    """
    # ``labels`` comes from ``*labels`` and is always a tuple, so the
    # meaningful "nothing to map" test is emptiness. Check it before the
    # dataset lookup so we do not fetch metadata for nothing.
    if not labels:
        return
    # TODO: Substitute 'text' with more features for zero-shot
    # we are not using ds features because we only support "text" for now
    ds_labels, _ = get_labels_and_features_from_dataset(
        dataset_id, dataset_config, dataset_split
    )
    if not isinstance(ds_labels, list):
        # Same "no usable labels" test as the prediction check: map no labels
        # instead of failing when indexing into a non-list.
        ds_labels = []

    all_mappings = read_column_mapping(uid)
    if all_mappings is None:
        all_mappings = {}

    label_mappings = all_mappings.setdefault("labels", {})
    for i, label in enumerate(labels[:MAX_LABELS]):
        # Dropdown slots past the dataset's label count may still carry a
        # value (presumably stale from a previous dataset -- TODO confirm);
        # only map positions that exist in the dataset.
        if label and i < len(ds_labels):
            label_mappings[label] = ds_labels[i]

    feature_mappings = all_mappings.setdefault("features", {})
    for feat in labels[MAX_LABELS : MAX_LABELS + MAX_FEATURES]:
        if feat:
            # TODO: Substitute 'text' with more features for zero-shot
            # Every truthy feature overwrites "text"; the last one wins.
            feature_mappings["text"] = feat

    write_column_mapping(all_mappings, uid)
def list_labels_and_features_from_dataset(ds_labels, ds_features, model_id2label):
    """Build the label and feature mapping dropdowns shown to the user.

    Args:
        ds_labels: Label names of the dataset. One visible dropdown is created
            per label (at most ``MAX_LABELS``), titled with the dataset label
            and offering the model labels as choices.
        ds_features: Feature (column) names of the dataset, offered as choices
            for the single "text" input.
        model_id2label: The model's id -> label dict (the caller passes
            ``ppl.model.config.id2label``).

    Returns:
        A list of exactly ``MAX_LABELS + MAX_FEATURES`` dropdowns: the visible
        label dropdowns padded with hidden ones, followed by the "text"
        feature dropdown padded with hidden ones.
    """
    model_labels = list(model_id2label.values())
    n_model_labels = len(model_labels)

    def _default_model_label(i):
        # Cycle through the model labels so each dataset label gets a
        # preselected value. Prefer the label registered under id ``i % n``;
        # fall back to positional order when ids are not 0..n-1. A model with
        # no labels gets no default instead of a ZeroDivisionError.
        if not n_model_labels:
            return None
        idx = i % n_model_labels
        return model_id2label.get(idx, model_labels[idx])

    labels = [
        gr.Dropdown(
            label=f"{label}",
            choices=model_labels,
            value=_default_model_label(i),
            interactive=True,
            visible=True,
        )
        for i, label in enumerate(ds_labels[:MAX_LABELS])
    ]
    labels += [gr.Dropdown(visible=False) for _ in range(MAX_LABELS - len(labels))]

    # TODO: Substitute 'text' with more features for zero-shot
    features = [
        gr.Dropdown(
            label=f"{feature}",
            choices=ds_features,
            # An empty feature list yields no preselection instead of IndexError.
            value=ds_features[0] if ds_features else None,
            interactive=True,
            visible=True,
        )
        for feature in ["text"]
    ]
    features += [
        gr.Dropdown(visible=False) for _ in range(MAX_FEATURES - len(features))
    ]
    return labels + features
def check_model_and_show_prediction(
    model_id, dataset_id, dataset_config, dataset_split
):
    """Validate the model/dataset pair and prepare the prediction UI outputs.

    Every exit returns a tuple of ``3 + MAX_LABELS + MAX_FEATURES`` Gradio
    updates:

    1. the example prediction input,
    2. the example prediction output,
    3. a component toggled via ``visible``/``open`` (presumably the
       column-mapping accordion -- confirm against the layout),
    4. one entry per label/feature mapping dropdown.

    Outcomes:

    - The model is missing or is not a ``TextClassificationPipeline``: warn
      and hide everything.
    - The dataset labels/features are unavailable or empty: hide everything.
    - The model labels differ from the dataset labels, or the first feature
      is not "text": warn and open the mapping panel for manual mapping.
    - Otherwise: show an example prediction with the panel collapsed.
    """
    ppl = check_model(model_id)
    # Also covers ``ppl is None``, so this is the single "model not usable"
    # exit. It must return as many outputs as every other branch.
    if ppl is None or not isinstance(ppl, TextClassificationPipeline):
        gr.Warning("Please check your model.")
        return _hidden_prediction_outputs()

    model_id2label = ppl.model.config.id2label
    ds_labels, ds_features = get_labels_and_features_from_dataset(
        dataset_id, dataset_config, dataset_split
    )
    # when dataset does not have labels or features
    if (
        not isinstance(ds_labels, list)
        or not isinstance(ds_features, list)
        or not ds_features  # nothing to map "text" to; ds_features[0] below needs one
    ):
        return _hidden_prediction_outputs()

    column_mappings = list_labels_and_features_from_dataset(
        ds_labels,
        ds_features,
        model_id2label,
    )

    # when labels or features are not aligned, show manual column mapping
    if (
        collections.Counter(model_id2label.values()) != collections.Counter(ds_labels)
        or ds_features[0] != "text"
    ):
        gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
        return (
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=True, open=True),
            *column_mappings,
        )

    prediction_input, prediction_output = get_example_prediction(
        ppl, dataset_id, dataset_config, dataset_split
    )
    return (
        gr.update(value=prediction_input, visible=True),
        gr.update(value=prediction_output, visible=True),
        gr.update(visible=True, open=False),
        *column_mappings,
    )


def _hidden_prediction_outputs():
    """Return the all-hidden output tuple for check_model_and_show_prediction."""
    return (
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False, open=False),
        *[gr.Dropdown(visible=False) for _ in range(MAX_LABELS + MAX_FEATURES)],
    )
# One lock shared by every submission. A lock created inside each call is
# never contended, so it provides no mutual exclusion. Assumes
# save_job_to_pipe uses the lock to guard writes to the shared job pipe --
# TODO confirm.
_JOB_PIPE_LOCK = threading.Lock()


def try_submit(m_id, d_id, config, split, local, uid):
    """Queue a Giskard scan of model ``m_id`` on dataset ``d_id``.

    Reads the user's saved column mapping and builds a ``giskard_scanner``
    command line. When ``local`` is true, it hands that command to
    ``save_job_to_pipe``.

    Args:
        m_id: Model id, passed as ``--model``.
        d_id: Dataset id, passed as ``--dataset``.
        config: Dataset config, passed as ``--dataset_config``.
        split: Dataset split, passed as ``--dataset_split``.
        local: If true, queue the job locally. Otherwise only show a TODO
            notice, since remote submission is not implemented.
        uid: User session whose column mapping, scan config and log are used.

    Returns:
        A pair of Gradio updates. The first is for the submit button, which
        is disabled once a job is queued. The second is for an output
        component (presumably the log textbox) that is shown when a job is
        queued.
    """
    all_mappings = read_column_mapping(uid)
    if (
        all_mappings is None
        or "labels" not in all_mappings
        or "features" not in all_mappings
    ):
        gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
        return (gr.update(interactive=True), gr.update(visible=False))

    # Keys of the stored label mapping, numbered 0..n-1 in stored order.
    label_mapping = {str(i): label for i, label in enumerate(all_mappings["labels"])}
    feature_mapping = all_mappings["features"]
    # TODO: Set column mapping for some dataset such as `amazon_polarity`

    if not local:
        gr.Info("TODO: Submit task to an endpoint")
        return (gr.update(interactive=True), gr.update(visible=False))  # Submit button

    # NOTE(review): os.environ.get may return None, which would put None
    # into the argument list. The write token also travels inside `command`;
    # confirm save_job_to_pipe does not persist it in plain text.
    command = [
        "giskard_scanner",
        "--loader",
        "huggingface",
        "--model",
        m_id,
        "--dataset",
        d_id,
        "--dataset_config",
        config,
        "--dataset_split",
        split,
        "--hf_token",
        os.environ.get(HF_WRITE_TOKEN),
        "--discussion_repo",
        os.environ.get(HF_REPO_ID) or os.environ.get(HF_SPACE_ID),
        "--output_format",
        "markdown",
        "--output_portal",
        "huggingface",
        "--feature_mapping",
        json.dumps(feature_mapping),
        "--label_mapping",
        json.dumps(label_mapping),
        "--scan_config",
        get_yaml_path(uid),
    ]

    eval_str = f"[{m_id}]<{d_id}({config}, {split} set)>"
    logging.info("Start local evaluation on %s", eval_str)
    save_job_to_pipe(uid, command, _JOB_PIPE_LOCK)
    write_log_to_user_file(
        uid,
        f"Start local evaluation on {eval_str}. Please wait for your job to start...\n",
    )
    gr.Info(f"Start local evaluation on {eval_str}")
    return (
        gr.update(interactive=False),
        gr.update(lines=5, visible=True, interactive=False),
    )