giskard-evaluator / text_classification_ui_helpers.py
inoki-giskard's picture
GSK-2434 Add component to show logs (#17)
8f809e2
raw
history blame
No virus
7.26 kB
import gradio as gr
from wordings import CONFIRM_MAPPING_DETAILS_FAIL_RAW
import json
import os
import logging
import threading
from io_utils import read_column_mapping, write_column_mapping, save_job_to_pipe, write_log_to_user_file
import datasets
import collections
from text_classification import get_labels_and_features_from_dataset, check_model, get_example_prediction
from transformers.pipelines import TextClassificationPipeline
# Number of label / feature mapping dropdowns the UI always returns; unused
# slots are filled with hidden dropdowns.
MAX_LABELS = 20
MAX_FEATURES = 20

# Names of the environment variables read with os.environ.get().
HF_REPO_ID = "HF_REPO_ID"
HF_SPACE_ID = "SPACE_ID"
HF_WRITE_TOKEN = "HF_WRITE_TOKEN"

# YAML file from which the saved column mapping is read.
CONFIG_PATH = "./config.yaml"
def check_dataset_and_get_config(dataset_id):
    """Reset the saved column mapping and list the configs of a dataset.

    Args:
        dataset_id: Hub identifier of the dataset.

    Returns:
        A visible ``gr.Dropdown`` with the dataset's config names and the first
        one pre-selected, or ``None`` when the configs cannot be obtained (the
        dataset may not exist, or it reports no configs).
    """
    try:
        # Selecting a (new) dataset discards any previously saved mapping.
        write_column_mapping(None)
        configs = datasets.get_dataset_config_names(dataset_id)
        return gr.Dropdown(configs, value=configs[0], visible=True)
    except Exception as e:
        # Best effort: the dataset may not exist. Log instead of silently
        # swallowing so unexpected failures remain diagnosable.
        logging.warning("Failed to get configs of dataset %s: %s", dataset_id, e)
        return None
def check_dataset_and_get_split(dataset_id, dataset_config):
    """List the splits of one config of a dataset.

    Args:
        dataset_id: Hub identifier of the dataset.
        dataset_config: config name as returned by ``check_dataset_and_get_config``.

    Returns:
        A visible ``gr.Dropdown`` with the split names and the first one
        pre-selected, or ``None`` when the dataset cannot be loaded or has no
        splits.
    """
    try:
        # NOTE(review): load_dataset loads the whole dataset just to read its
        # split names; datasets.get_dataset_split_names may be cheaper, but
        # later steps may rely on this call warming the cache — confirm first.
        splits = list(datasets.load_dataset(dataset_id, dataset_config).keys())
        return gr.Dropdown(splits, value=splits[0], visible=True)
    except Exception as e:
        # Best effort: the dataset/config may not exist. Log the reason rather
        # than silently swallowing it.
        logging.warning(
            "Failed to load dataset %s with config %s: %s",
            dataset_id,
            dataset_config,
            e,
        )
        return None
def write_column_mapping_to_config(dataset_id, dataset_config, dataset_split, *labels):
    """Merge the mapping selected in the UI dropdowns into the saved config.

    Args:
        dataset_id, dataset_config, dataset_split: identify the dataset whose
            labels/features the dropdowns were built from.
        *labels: dropdown values. The first ``MAX_LABELS`` entries come from
            the label dropdowns, the next ``MAX_FEATURES`` from the feature
            dropdowns. Falsy values (e.g. unselected dropdowns) are skipped.

    The result is merged into what ``read_column_mapping(CONFIG_PATH)``
    returns and saved with ``write_column_mapping``. Nothing is written when no
    values are given or the dataset does not expose labels/features as lists.
    """
    # ``labels`` is a *args tuple and can never be None; check for emptiness,
    # and do it before the dataset lookup.
    if not labels:
        return
    ds_labels, ds_features = get_labels_and_features_from_dataset(dataset_id, dataset_config, dataset_split)
    # Same guard as check_model_and_show_prediction: without list-valued
    # labels/features there is nothing to map the selections onto.
    if not isinstance(ds_labels, list) or not isinstance(ds_features, list):
        return

    all_mappings = read_column_mapping(CONFIG_PATH)
    if all_mappings is None:
        all_mappings = dict()
    label_mapping = all_mappings.setdefault("labels", dict())
    feature_mapping = all_mappings.setdefault("features", dict())

    # Label dropdown i is created for ds_labels[i] (see
    # list_labels_and_features_from_dataset). zip() stops at the shorter
    # sequence, so a value left in a hidden dropdown cannot index past the end.
    for label, ds_label in zip(labels[:MAX_LABELS], ds_labels):
        if label:
            label_mapping[label] = ds_label

    # NOTE(review): label entries are keyed by the selected model label and map
    # to the dropdown's own dataset label. Feature entries, however, are keyed
    # by the selected dataset column and map to ds_features[i], not to the
    # dropdown's own feature name ('text'). Confirm the direction cli.py
    # expects for --feature_mapping.
    for feat, ds_feature in zip(labels[MAX_LABELS:(MAX_LABELS + MAX_FEATURES)], ds_features):
        if feat:
            feature_mapping[feat] = ds_feature

    write_column_mapping(all_mappings)
def list_labels_and_features_from_dataset(ds_labels, ds_features, model_id2label):
    """Build the dropdowns used to map dataset labels/features to the model.

    Args:
        ds_labels: labels of the dataset; one visible dropdown is created per
            label (at most ``MAX_LABELS``), titled with that label.
        ds_features: column names of the dataset, offered as choices in the
            feature dropdown.
        model_id2label: the model's id -> label mapping (``config.id2label``).

    Returns:
        A list of ``MAX_LABELS`` label dropdowns followed by ``MAX_FEATURES``
        feature dropdowns; slots that are not needed are hidden dropdowns.
    """
    model_labels = list(model_id2label.values())
    logging.debug("Model labels: %s (id2label: %s)", model_labels, model_id2label)

    def _default_model_label(i):
        # Pre-select model labels in id order, wrapping around when the dataset
        # has more labels than the model. Fall back to positional lookup when
        # id2label keys are not 0..n-1 (avoids KeyError), and to None when the
        # model has no labels (avoids ZeroDivisionError).
        if not model_labels:
            return None
        idx = i % len(model_labels)
        return model_id2label.get(idx, model_labels[idx])

    labels = [
        gr.Dropdown(
            label=f"{label}",
            choices=model_labels,
            value=_default_model_label(i),
            interactive=True,
            visible=True,
        )
        for i, label in enumerate(ds_labels[:MAX_LABELS])
    ]
    labels += [gr.Dropdown(visible=False) for _ in range(MAX_LABELS - len(labels))]

    # TODO: Substitute 'text' with more features for zero-shot
    features = [
        gr.Dropdown(
            label=f"{feature}",
            choices=ds_features,
            value=ds_features[0] if ds_features else None,
            interactive=True,
            visible=True,
        )
        for feature in ["text"]
    ]
    features += [gr.Dropdown(visible=False) for _ in range(MAX_FEATURES - len(features))]
    return labels + features
def _hidden_prediction_outputs():
    """Return outputs that hide the prediction, the mapping section and all dropdowns."""
    return (
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False, open=False),
        *[gr.Dropdown(visible=False) for _ in range(MAX_LABELS + MAX_FEATURES)],
    )


def check_model_and_show_prediction(model_id, dataset_id, dataset_config, dataset_split):
    """Validate the model and show an example prediction plus mapping dropdowns.

    Args:
        model_id: identifier passed to ``check_model``.
        dataset_id, dataset_config, dataset_split: dataset to take labels,
            features and the example input from.

    Returns:
        A tuple of ``3 + MAX_LABELS + MAX_FEATURES`` Gradio updates, in order:
        example input, example output, the column-mapping section (it takes
        ``open=``, presumably a ``gr.Accordion``), then the label and feature
        dropdowns. Every branch returns the same number of outputs.
    """
    ppl = check_model(model_id)
    if ppl is None:  # pipeline not found
        gr.Warning("Model not found")
        return _hidden_prediction_outputs()
    if not isinstance(ppl, TextClassificationPipeline):
        gr.Warning("Please check your model.")
        return _hidden_prediction_outputs()

    model_id2label = ppl.model.config.id2label
    ds_labels, ds_features = get_labels_and_features_from_dataset(dataset_id, dataset_config, dataset_split)

    # when dataset does not have labels or features
    if not isinstance(ds_labels, list) or not isinstance(ds_features, list):
        return _hidden_prediction_outputs()

    column_mappings = list_labels_and_features_from_dataset(
        ds_labels,
        ds_features,
        model_id2label,
    )

    # When labels or features are not aligned, open the manual column mapping.
    labels_aligned = collections.Counter(model_id2label.values()) == collections.Counter(ds_labels)
    features_aligned = bool(ds_features) and ds_features[0] == "text"
    if not (labels_aligned and features_aligned):
        gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
        return (
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=True, open=True),
            *column_mappings,
        )

    prediction_input, prediction_output = get_example_prediction(ppl, dataset_id, dataset_config, dataset_split)
    return (
        gr.update(value=prediction_input, visible=True),
        gr.update(value=prediction_output, visible=True),
        gr.update(visible=True, open=False),
        *column_mappings,
    )
# One lock shared by all submissions, so calls into save_job_to_pipe can
# actually be serialised; a Lock created per call excludes nothing.
_JOB_PIPE_LOCK = threading.Lock()


def try_submit(m_id, d_id, config, split, local, uid):
    """Queue an evaluation of ``m_id`` on ``d_id`` using the saved column mapping.

    Args:
        m_id: model id, passed to cli.py as ``--model``.
        d_id: dataset id, passed as ``--dataset``.
        config: dataset config, passed as ``--dataset_config``.
        split: dataset split, passed as ``--dataset_split``.
        local: when true, queue a local ``cli.py`` run; otherwise only a TODO
            message is shown.
        uid: key passed to save_job_to_pipe and write_log_to_user_file
            (presumably a per-user/session id).

    Returns:
        A pair of Gradio updates: the submit button, then the log output
        component (it takes ``lines=``).
    """
    all_mappings = read_column_mapping(CONFIG_PATH)
    if all_mappings is None or "labels" not in all_mappings or "features" not in all_mappings:
        gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
        return (gr.update(interactive=True), gr.update(visible=False))
    label_mapping = all_mappings["labels"]
    feature_mapping = all_mappings["features"]

    if not local:
        gr.Info("TODO: Submit task to an endpoint")
        return (gr.update(interactive=True),  # Submit button
                gr.update(visible=False))

    hf_token = os.environ.get(HF_WRITE_TOKEN)
    discussion_repo = os.environ.get(HF_REPO_ID) or os.environ.get(HF_SPACE_ID)
    if not hf_token or not discussion_repo:
        # Every argv entry must be a string. A missing variable would otherwise
        # put None into the queued command (presumably failing when the job is
        # launched), so fail fast with a visible message instead.
        gr.Warning(
            f"Missing environment variable {HF_WRITE_TOKEN} or "
            f"{HF_REPO_ID}/{HF_SPACE_ID}; cannot start the evaluation."
        )
        return (gr.update(interactive=True), gr.update(visible=False))

    # TODO: Set column mapping for some dataset such as `amazon_polarity`
    command = [
        "python",
        "cli.py",
        "--loader", "huggingface",
        "--model", m_id,
        "--dataset", d_id,
        "--dataset_config", config,
        "--dataset_split", split,
        "--hf_token", hf_token,
        "--discussion_repo", discussion_repo,
        "--output_format", "markdown",
        "--output_portal", "huggingface",
        "--feature_mapping", json.dumps(feature_mapping),
        "--label_mapping", json.dumps(label_mapping),
        # NOTE(review): differs from CONFIG_PATH ("./config.yaml"); presumably
        # resolved from cli.py's working directory — confirm.
        "--scan_config", "../config.yaml",
    ]
    eval_str = f"[{m_id}]<{d_id}({config}, {split} set)>"
    logging.info("Start local evaluation on %s", eval_str)
    save_job_to_pipe(uid, command, _JOB_PIPE_LOCK)
    write_log_to_user_file(uid, f"Start local evaluation on {eval_str}. Please wait for your job to start...\n")
    gr.Info(f"Start local evaluation on {eval_str}")
    return (
        gr.update(interactive=False),
        gr.update(lines=5, visible=True, interactive=False))