Spaces:
Sleeping
Sleeping
import gradio as gr | |
import datasets | |
import os | |
import time | |
import subprocess | |
import logging | |
import json | |
from transformers.pipelines import TextClassificationPipeline | |
from text_classification import get_labels_and_features_from_dataset, check_model, get_example_prediction | |
from io_utils import read_scanners, write_scanners, read_inference_type, read_column_mapping, write_column_mapping, write_inference_type | |
from wordings import CONFIRM_MAPPING_DETAILS_MD, CONFIRM_MAPPING_DETAILS_FAIL_RAW | |
# Environment-variable *names* that try_submit looks up with os.environ.get.
HF_REPO_ID: str = 'HF_REPO_ID'
HF_SPACE_ID: str = 'SPACE_ID'
HF_WRITE_TOKEN: str = 'HF_WRITE_TOKEN'

# Number of label / feature mapping dropdowns pre-allocated in the UI.
MAX_LABELS: int = 20
MAX_FEATURES: int = 20

# Example ids shown as placeholders in the model / dataset textboxes.
EXAMPLE_MODEL_ID: str = 'cardiffnlp/twitter-roberta-base-sentiment-latest'
EXAMPLE_DATA_ID: str = 'tweet_eval'

# YAML config file that the column mappings are read from.
CONFIG_PATH: str = './config.yaml'
def try_submit(m_id, d_id, config, split, local):
    """Start a Giskard scan of model ``m_id`` on dataset ``d_id``.

    The label and feature column mappings are read from ``CONFIG_PATH``. If
    either is missing, or (for a local run) a required input or environment
    variable is empty, a Gradio warning is shown and nothing is run.

    Args:
        m_id: Hugging Face model id.
        d_id: Hugging Face dataset id.
        config: Dataset config name.
        split: Dataset split name.
        local: If truthy, run ``cli.py`` from the ``cicd`` directory next to
            this file in a subprocess and block until it exits; otherwise only
            a "TODO" notice is shown.

    Returns:
        ``gr.update(interactive=True)`` for the submit button, on every path.
    """
    # NOTE(review): `or {}` assumes read_column_mapping may return None, e.g.
    # after clear_column_mapping_config calls write_column_mapping(None) — confirm.
    all_mappings = read_column_mapping(CONFIG_PATH) or {}
    label_mapping = all_mappings.get("labels")
    feature_mapping = all_mappings.get("features")
    if label_mapping is None or feature_mapping is None:
        gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
        return gr.update(interactive=True)

    # TODO: Set column mapping for some dataset such as `amazon_polarity`
    if not local:
        gr.Info("TODO: Submit task to an endpoint")
        return gr.update(interactive=True)  # Submit button

    hf_token = os.environ.get(HF_WRITE_TOKEN)
    discussion_repo = os.environ.get(HF_REPO_ID) or os.environ.get(HF_SPACE_ID)
    # Every value below becomes an argv entry: subprocess raises TypeError on
    # None entries, and an empty string cannot name a model/dataset/repo, so
    # report all missing values up front instead of crashing the handler.
    required = {
        "model id": m_id,
        "dataset id": d_id,
        "dataset config": config,
        "dataset split": split,
        f"{HF_WRITE_TOKEN} environment variable": hf_token,
        f"{HF_REPO_ID} or {HF_SPACE_ID} environment variable": discussion_repo,
    }
    missing = [name for name, value in required.items() if not value]
    if missing:
        gr.Warning(f"Cannot start evaluation, missing: {', '.join(missing)}")
        return gr.update(interactive=True)

    cicd_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "cicd")
    command = [
        "python",
        "cli.py",
        "--loader", "huggingface",
        "--model", m_id,
        "--dataset", d_id,
        "--dataset_config", config,
        "--dataset_split", split,
        "--hf_token", hf_token,
        "--discussion_repo", discussion_repo,
        "--output_format", "markdown",
        "--output_portal", "huggingface",
        "--feature_mapping", json.dumps(feature_mapping),
        "--label_mapping", json.dumps(label_mapping),
        # The subprocess runs with cwd=cicd_dir; pass an absolute path so it
        # reads the same config file the mappings above were read from.
        "--scan_config", os.path.abspath(CONFIG_PATH),
    ]
    eval_str = f"[{m_id}]<{d_id}({config}, {split} set)>"
    start = time.perf_counter()
    logging.info(f"Start local evaluation on {eval_str}")
    result = subprocess.run(
        command,
        cwd=cicd_dir,
        stderr=subprocess.STDOUT,
    ).returncode
    elapsed = time.perf_counter() - start
    message = f"Finished local evaluation exit code {result} on {eval_str}: {elapsed:.2f}s"
    logging.info(message)
    gr.Info(message)
    return gr.update(interactive=True)  # Submit button
def check_dataset_and_get_config(dataset_id):
    """Return a visible dropdown of the configs available for ``dataset_id``.

    The first config is pre-selected. Returns ``None`` in three cases: when
    ``dataset_id`` is empty, when the lookup fails (e.g. the dataset does not
    exist), and when no configs are reported. Failures are logged rather than
    silently dropped.
    """
    if not dataset_id:
        return None
    try:
        configs = datasets.get_dataset_config_names(dataset_id)
    except Exception as e:  # best-effort UI lookup: hub/network/auth errors vary
        # Dataset may not exist
        logging.warning(f"Failed to get configs of dataset {dataset_id}: {e}")
        return None
    if not configs:
        logging.warning(f"No config found for dataset {dataset_id}")
        return None
    return gr.Dropdown(configs, value=configs[0], visible=True)
def check_dataset_and_get_split(dataset_id, dataset_config):
    """Return a visible dropdown of the splits of ``dataset_id``/``dataset_config``.

    The first split is pre-selected. Returns ``None`` in three cases: when
    ``dataset_id`` is empty, when the lookup fails, and when no splits are
    reported. Failures are logged rather than silently dropped.
    """
    if not dataset_id:
        return None
    try:
        # Only the split names are needed; load_dataset() would download and
        # prepare every split just to read their keys.
        splits = datasets.get_dataset_split_names(dataset_id, dataset_config)
    except Exception as e:  # best-effort UI lookup: hub/network/auth errors vary
        # Dataset may not exist
        logging.warning(f"Failed to load dataset {dataset_id} with config {dataset_config}: {e}")
        return None
    if not splits:
        logging.warning(f"No split found for dataset {dataset_id} with config {dataset_config}")
        return None
    return gr.Dropdown(splits, value=splits[0], visible=True)
def get_demo():
    """Build the text-classification evaluation UI and wire its event handlers.

    Creates components and listeners only and returns nothing.
    NOTE(review): presumably called inside a ``gr.Blocks`` context by the
    caller — not visible here.
    """
    with gr.Row():
        gr.Markdown(CONFIRM_MAPPING_DETAILS_MD)
    with gr.Row():
        model_id_input = gr.Textbox(
            label="Hugging Face model id",
            placeholder=EXAMPLE_MODEL_ID + " (press enter to confirm)",
        )
        dataset_id_input = gr.Textbox(
            label="Hugging Face Dataset id",
            placeholder=EXAMPLE_DATA_ID + " (press enter to confirm)",
        )
    with gr.Row():
        dataset_config_input = gr.Dropdown(label='Dataset Config', visible=False)
        dataset_split_input = gr.Dropdown(label='Dataset Split', visible=False)
    with gr.Row():
        example_input = gr.Markdown('Example Input', visible=False)
    with gr.Row():
        example_prediction = gr.Label(label='Model Prediction Sample', visible=False)
    with gr.Row():
        # Slots [0, MAX_LABELS) map labels; [MAX_LABELS, MAX_LABELS + MAX_FEATURES) map features.
        column_mappings = []
        with gr.Column():
            for _ in range(MAX_LABELS):
                column_mappings.append(gr.Dropdown(visible=False))
        with gr.Column():
            for _ in range(MAX_FEATURES):
                column_mappings.append(gr.Dropdown(visible=False))
    with gr.Accordion(label='Model Wrap Advance Config (optional)', open=False):
        run_local = gr.Checkbox(value=True, label="Run in this Space")
        use_inference = read_inference_type(CONFIG_PATH) == 'hf_inference_api'
        run_inference = gr.Checkbox(value=use_inference, label="Run with Inference API")
    with gr.Accordion(label='Scanner Advance Config (optional)', open=False):
        selected = read_scanners(CONFIG_PATH)
        # Offer data_leakage as an extra choice without listing it twice when
        # the config already selects it.
        scan_config = selected + [s for s in ['data_leakage'] if s not in selected]
        scanners = gr.CheckboxGroup(choices=scan_config, value=selected, label='Scan Settings', visible=True)
    with gr.Row():
        run_btn = gr.Button(
            "Get Evaluation Result",
            variant="primary",
            interactive=True,
            size="lg",
        )

    # Model-side input feature names. One feature dropdown is shown per entry,
    # in this order, so feature slot i corresponds to model_features[i].
    model_features = ['text']

    def write_column_mapping_to_config(dataset_id, dataset_config, dataset_split, *labels):
        """Merge the dropdown selections into the column mappings and save them.

        ``labels`` holds the values of all ``column_mappings`` dropdowns. The
        first ``MAX_LABELS`` are the model labels chosen for each dataset label
        (saved as ``{model_label: dataset_label}``). The rest are the dataset
        columns chosen for each model feature (saved as
        ``{model_feature: dataset_column}``).
        """
        ds_labels, _ = get_labels_and_features_from_dataset(dataset_id, dataset_config, dataset_split)
        # `*labels` is always a tuple (never None), so test for emptiness.
        if not labels or ds_labels is None:
            return
        all_mappings = read_column_mapping(CONFIG_PATH) or {}
        label_mapping = all_mappings.get("labels") or {}
        # zip() stops at the dataset's label count, so stale values left in
        # hidden dropdowns cannot index past ds_labels.
        for ds_label, model_label in zip(ds_labels, labels[:MAX_LABELS]):
            if model_label:
                label_mapping[model_label] = ds_label
        all_mappings["labels"] = label_mapping
        feature_mapping = all_mappings.get("features") or {}
        for model_feature, ds_feature in zip(model_features, labels[MAX_LABELS:MAX_LABELS + MAX_FEATURES]):
            if ds_feature:
                feature_mapping[model_feature] = ds_feature
        all_mappings["features"] = feature_mapping
        write_column_mapping(all_mappings)

    def list_labels_and_features_from_dataset(dataset_id, dataset_config, dataset_split, model_id2label, model_features):
        """Build updates for all ``MAX_LABELS + MAX_FEATURES`` mapping dropdowns.

        The result has one visible dropdown per dataset label (choices: model
        labels), then one per model feature (choices: dataset columns). Unused
        slots are hidden. All slots are hidden if the dataset cannot be read.
        """
        ds_labels, ds_features = get_labels_and_features_from_dataset(dataset_id, dataset_config, dataset_split)
        if ds_labels is None or ds_features is None:
            return [gr.Dropdown(visible=False) for _ in range(MAX_LABELS + MAX_FEATURES)]
        model_labels = list(model_id2label.values())
        labels = [
            gr.Dropdown(
                label=f"{label}",
                choices=model_labels,
                # Pre-select the model label with the same index; .get() because
                # the dataset may have more labels than the model.
                value=model_id2label.get(i),
                interactive=True,
                visible=True,
            )
            for i, label in enumerate(ds_labels[:MAX_LABELS])
        ]
        labels += [gr.Dropdown(visible=False) for _ in range(MAX_LABELS - len(labels))]
        features = [
            gr.Dropdown(
                label=f"{feature}",
                choices=ds_features,
                value=ds_features[0] if ds_features else None,
                interactive=True,
                visible=True,
            )
            for feature in model_features[:MAX_FEATURES]
        ]
        features += [gr.Dropdown(visible=False) for _ in range(MAX_FEATURES - len(features))]
        return labels + features

    def clear_column_mapping_config():
        """Reset the stored column mappings by writing ``None``."""
        write_column_mapping(None)

    def check_model_and_show_prediction(model_id, dataset_id, dataset_config, dataset_split):
        """Validate ``model_id`` and preview one prediction plus the mapping dropdowns.

        Returns a tuple of updates for ``example_input``, ``example_prediction``
        and then every ``column_mappings`` dropdown. Everything is hidden when
        the model is not a ``TextClassificationPipeline``.
        """
        ppl = check_model(model_id)
        # isinstance() is False for None, so this also covers a missing model.
        if not isinstance(ppl, TextClassificationPipeline):
            gr.Warning("Please check your model.")
            return (
                gr.update(visible=False),
                gr.update(visible=False),
                *[gr.update(visible=False) for _ in range(MAX_LABELS + MAX_FEATURES)],
            )
        model_id2label = ppl.model.config.id2label
        mapping_updates = list_labels_and_features_from_dataset(
            dataset_id,
            dataset_config,
            dataset_split,
            model_id2label,
            model_features,
        )
        prediction_input, prediction_output = get_example_prediction(ppl, dataset_id, dataset_config, dataset_split)
        return (
            gr.update(value=prediction_input, visible=True),
            gr.update(value=prediction_output, visible=True),
            *mapping_updates,
        )

    # Pressing Enter in a Gradio Textbox fires `submit`, not `blur`; listen to
    # both so the "press enter to confirm" placeholder is honoured.
    gr.on(
        triggers=[dataset_id_input.blur, dataset_id_input.submit],
        fn=check_dataset_and_get_config,
        inputs=[dataset_id_input],
        outputs=[dataset_config_input],
    )
    dataset_config_input.change(
        check_dataset_and_get_split,
        inputs=[dataset_id_input, dataset_config_input],
        outputs=[dataset_split_input])
    run_inference.change(
        write_inference_type,
        inputs=[run_inference]
    )
    # NOTE(review): in the code visible here, check_model_and_show_prediction,
    # write_column_mapping_to_config, clear_column_mapping_config and the
    # `scanners` group are never attached to an event. As a result, the
    # label/feature mappings that try_submit requires are never written from
    # this UI. Confirm whether that wiring lives elsewhere.
    gr.on(
        triggers=[
            run_btn.click,
        ],
        fn=try_submit,
        inputs=[model_id_input, dataset_id_input, dataset_config_input, dataset_split_input, run_local],
        outputs=[run_btn])