|
import collections |
|
import logging |
|
import threading |
|
import uuid |
|
|
|
import datasets |
|
import gradio as gr |
|
import pandas as pd |
|
|
|
import leaderboard |
|
from io_utils import ( |
|
read_column_mapping, |
|
write_column_mapping, |
|
read_scanners, |
|
write_scanners, |
|
) |
|
from run_jobs import save_job_to_pipe |
|
from text_classification import ( |
|
strip_model_id_from_url, |
|
check_model_task, |
|
preload_hf_inference_api, |
|
get_example_prediction, |
|
get_labels_and_features_from_dataset, |
|
check_hf_token_validity, |
|
HuggingFaceInferenceAPIResponse, |
|
) |
|
from wordings import ( |
|
CHECK_CONFIG_OR_SPLIT_RAW, |
|
CONFIRM_MAPPING_DETAILS_FAIL_RAW, |
|
MAPPING_STYLED_ERROR_WARNING, |
|
NOT_TEXT_CLASSIFICATION_MODEL_RAW, |
|
UNMATCHED_MODEL_DATASET_STYLED_ERROR, |
|
CHECK_LOG_SECTION_RAW, |
|
VALIDATED_MODEL_DATASET_STYLED, |
|
get_dataset_fetch_error_raw, |
|
) |
|
import os |
|
from app_env import HF_WRITE_TOKEN |
|
|
|
# Number of label dropdown slots produced by the column-mapping callbacks below.
MAX_LABELS = 40
# Number of feature dropdown slots produced by the column-mapping callbacks below.
MAX_FEATURES = 20

# Module-level placeholders; no function visible in this part of the file
# reads or writes them.
ds_dict = None
ds_config = None
|
|
|
def get_related_datasets_from_leaderboard(model_id):
    """Return a Gradio update whose choices are the datasets recorded for a model.

    The model id is normalised with ``strip_model_id_from_url`` and matched
    against the ``model_id`` column of ``leaderboard.records``; the distinct
    ``dataset_id`` values of the matching rows become the choices. When no row
    matches, the choices are an empty list.
    """
    normalized_id = strip_model_id_from_url(model_id)
    board = leaderboard.records
    matching_rows = board[board["model_id"] == normalized_id]
    dataset_choices = list(matching_rows["dataset_id"].unique())

    return gr.update(choices=dataset_choices if len(dataset_choices) > 0 else [])
|
|
|
|
|
# Module-wide logger; it is named after this file's path (``__file__``).
logger = logging.getLogger(__file__)
|
|
|
def get_dataset_splits(dataset_id, dataset_config):
    """Return a Gradio update listing the splits of a dataset configuration.

    Args:
        dataset_id: Dataset identifier passed to
            ``datasets.get_dataset_split_names``.
        dataset_config: Configuration name passed to the same call.

    Returns:
        On success, an update that sets the choices to the split names,
        selects the first one and makes the component visible. If the lookup
        raises or yields no split, a warning is logged and the returned update
        hides the component.
    """
    try:
        splits = datasets.get_dataset_split_names(
            dataset_id, dataset_config, trust_remote_code=True
        )
    except Exception as e:
        # Broad on purpose: any failure of the lookup should only hide the
        # component, not break the UI callback.
        logger.warning(f"Check your dataset {dataset_id} and config {dataset_config}: {e}")
        return gr.update(visible=False)

    if not splits:
        # Handled explicitly instead of relying on an IndexError from splits[0].
        logger.warning(
            f"Check your dataset {dataset_id} and config {dataset_config}: no split found"
        )
        return gr.update(visible=False)

    return gr.update(choices=splits, value=splits[0], visible=True)
|
|
|
def check_dataset(dataset_id):
    """Look up the configurations and splits of a dataset for the UI.

    Args:
        dataset_id: Dataset identifier passed to the ``datasets`` lookups.

    Returns:
        A 3-tuple ``(config_update, split_update, "")``. On success the first
        update lists the configuration names (first one selected) and the
        second lists the split names of that first configuration (first one
        selected), both visible. When there is no configuration, no split, or
        the lookup raises, both updates hide their component.
    """

    def hidden():
        # Outputs used whenever nothing can be offered for selection.
        return (gr.update(visible=False), gr.update(visible=False), "")

    logger.info(f"Loading {dataset_id}")
    try:
        # NOTE(review): trust_remote_code=True lets the dataset repository run
        # its own loading script; dataset_id appears to come from user input.
        configs = datasets.get_dataset_config_names(dataset_id, trust_remote_code=True)
        if not configs:
            return hidden()
        splits = datasets.get_dataset_split_names(
            dataset_id, configs[0], trust_remote_code=True
        )
    except Exception as e:
        logger.warning(f"Check your dataset {dataset_id}: {e}")
        message = str(e)
        # Surface the error to the user once, even if both phrases match.
        if "doesn't exist" in message or "forbidden" in message.lower():
            gr.Warning(get_dataset_fetch_error_raw(e))
        return hidden()

    if not splits:
        # Handled explicitly instead of relying on an IndexError from splits[0].
        logger.warning(f"Check your dataset {dataset_id}: no split found for {configs[0]}")
        return hidden()

    return (
        gr.update(choices=configs, value=configs[0], visible=True),
        gr.update(choices=splits, value=splits[0], visible=True),
        "",
    )
|
|
|
def empty_column_mapping(uid):
    """Call ``write_column_mapping`` with ``None`` as the mapping for ``uid``.

    Presumably this clears the mapping stored for that session; confirm
    against ``io_utils.write_column_mapping``.
    """
    write_column_mapping(None, uid)
|
|
|
def write_column_mapping_to_config(uid, *labels):
    """Store the selected dropdown values in the column mapping of ``uid``.

    Args:
        uid: Identifier passed to ``read_column_mapping`` and
            ``write_column_mapping``.
        *labels: Flat sequence of selected values. The first ``MAX_LABELS``
            entries are assigned to the label keys already present in the
            mapping; the following ``MAX_FEATURES`` entries are assigned to
            the ``"text"`` feature.

    Nothing is written when no value is passed or when no mapping can be read
    for ``uid``.
    """
    # ``*labels`` is always a tuple (never None), so test for emptiness.
    if not labels:
        return

    all_mappings = read_column_mapping(uid)
    if all_mappings is None:
        # Without an existing mapping there are no label keys to update.
        logger.warning(f"No column mapping found for {uid}; nothing to update")
        return

    label_values = labels[:MAX_LABELS]
    feature_values = labels[MAX_LABELS : (MAX_LABELS + MAX_FEATURES)]

    all_mappings = export_mappings(all_mappings, "labels", None, label_values)
    # export_mappings cycles through its values with a modulo, so an empty
    # feature slice must not be forwarded to it.
    if feature_values:
        all_mappings = export_mappings(
            all_mappings,
            "features",
            ["text"],
            feature_values,
        )

    write_column_mapping(all_mappings, uid)
|
|
|
def export_mappings(all_mappings, key, subkeys, values):
    """Assign ``values`` to ``subkeys`` inside ``all_mappings[key]``.

    Args:
        all_mappings: Mapping to update in place; ``None`` is treated as an
            empty dict.
        key: Top-level section to update; created as an empty dict if absent.
        subkeys: Names to assign within the section. ``None`` means "the keys
            already present in the section".
        values: Sequence of values. Subkey number ``i`` receives
            ``values[i % len(values)]``, so values are reused cyclically when
            there are fewer values than subkeys.

    Returns:
        The updated mapping. It is returned unchanged (apart from the section
        being created) when there are no subkeys or no values. Falsy subkeys
        are skipped but still occupy their position in the cycle.
    """
    if all_mappings is None:
        all_mappings = {}

    section = all_mappings.setdefault(key, {})
    if subkeys is None:
        subkeys = list(section.keys())

    if not subkeys:
        logging.debug(f"subkeys is empty for {key}")
        return all_mappings

    # Guard the modulo below against an empty sequence of values.
    if len(values) == 0:
        logging.debug(f"values is empty for {key}")
        return all_mappings

    for i, subkey in enumerate(subkeys):
        if subkey:
            section[subkey] = values[i % len(values)]
    return all_mappings
|
|
|
|
|
def list_labels_and_features_from_dataset(ds_labels, ds_features, model_labels, uid):
    """Build the mapping dropdowns and persist the initial column mapping.

    The dataset labels are narrowed to those shared with the model when any
    overlap exists, sorted, and capped at ``MAX_LABELS`` (a UI warning is
    raised when labels are dropped). One visible dropdown is created per
    remaining dataset label, offering the model labels as choices, and one for
    the ``"text"`` feature, offering the dataset features as choices. Both
    groups are padded with hidden dropdowns. The same pairing is written to
    the column mapping of ``uid``.

    Args:
        ds_labels: Label names of the dataset; not modified.
        ds_features: Feature names of the dataset; must be non-empty because
            its first entry is the default for the ``"text"`` feature.
        model_labels: Label names of the model; not modified. Must be
            non-empty when ``ds_labels`` is non-empty.
        uid: Identifier passed to ``read_column_mapping`` and
            ``write_column_mapping``.

    Returns:
        A list of ``MAX_LABELS + MAX_FEATURES`` ``gr.Dropdown`` objects: the
        label dropdowns followed by the feature dropdowns.
    """
    # Start from an empty mapping when nothing has been stored yet.
    all_mappings = read_column_mapping(uid) or {}

    # Sorted copies keep the caller's lists untouched and make the selection
    # below deterministic (a set has no stable order).
    shared_labels = set(model_labels).intersection(ds_labels)
    ds_labels = sorted(shared_labels) if shared_labels else sorted(ds_labels)
    model_labels = sorted(model_labels)

    if len(ds_labels) > MAX_LABELS:
        ds_labels = ds_labels[:MAX_LABELS]
        gr.Warning(f"Too many labels to display for this space. We do not support more than {MAX_LABELS} in this space. You can use cli tool at https://github.com/Giskard-AI/cicd.")

    label_dropdowns = [
        gr.Dropdown(
            label=f"{label}",
            choices=model_labels,
            # Default pairing: i-th dataset label with i-th model label.
            value=model_labels[i % len(model_labels)],
            interactive=True,
            visible=True,
        )
        for i, label in enumerate(ds_labels)
    ]
    label_dropdowns += [
        gr.Dropdown(visible=False) for _ in range(MAX_LABELS - len(label_dropdowns))
    ]
    all_mappings = export_mappings(all_mappings, "labels", ds_labels, model_labels)

    feature_names = ["text"]
    feature_dropdowns = [
        gr.Dropdown(
            label=f"{feature}",
            choices=ds_features,
            value=ds_features[0],
            interactive=True,
            visible=True,
        )
        for feature in feature_names
    ]
    feature_dropdowns += [
        gr.Dropdown(visible=False) for _ in range(MAX_FEATURES - len(feature_dropdowns))
    ]
    all_mappings = export_mappings(all_mappings, "features", feature_names, ds_features)
    write_column_mapping(all_mappings, uid)

    return label_dropdowns + feature_dropdowns
|
|
|
|
|
def _precheck_outputs(interactive, preview=None):
    """Build the six Gradio updates returned by the model/dataset pre-check.

    The first update sets the ``interactive`` flag, the second shows
    ``preview`` (or hides its component when ``preview`` is ``None``), and the
    remaining four hide their components.
    """
    if preview is None:
        preview_update = gr.update(visible=False)
    else:
        preview_update = gr.update(value=preview, visible=True)
    return (
        gr.update(interactive=interactive),
        preview_update,
        *(gr.update(visible=False) for _ in range(4)),
    )


def precheck_model_ds_enable_example_btn(
    model_id, dataset_id, dataset_config, dataset_split
):
    """Validate a model/dataset selection and preview the dataset.

    The model task is checked and the inference API preloaded first. Then the
    requested split is loaded and its first five rows are turned into a
    DataFrame preview.

    Args:
        model_id: Model id or URL; normalised with ``strip_model_id_from_url``.
        dataset_id: Dataset identifier passed to ``datasets.load_dataset``.
        dataset_config: Dataset configuration name.
        dataset_split: Dataset split name.

    Returns:
        A 6-tuple of Gradio updates. The first is interactive only when the
        model task is ``"text-classification"`` and the dataset yields label
        and feature lists. The second shows the preview when the dataset could
        be loaded and is hidden otherwise. The other four always hide their
        components.
    """
    model_id = strip_model_id_from_url(model_id)
    model_task = check_model_task(model_id)
    preload_hf_inference_api(model_id)

    if dataset_config is None or dataset_split is None or len(dataset_config) == 0:
        return _precheck_outputs(interactive=False)

    try:
        # Load only the requested split rather than every split of the config.
        ds = datasets.load_dataset(
            dataset_id, dataset_config, split=dataset_split, trust_remote_code=True
        )
        df: pd.DataFrame = ds.to_pandas().head(5)
        ds_labels, ds_features, _ = get_labels_and_features_from_dataset(ds)
    except Exception as e:
        # Broad on purpose: any loading failure only disables the controls.
        logger.warning(f"Check your dataset {dataset_id} and config {dataset_config} on split {dataset_split}: {e}")
        return _precheck_outputs(interactive=False)

    if model_task != "text-classification":
        gr.Warning(NOT_TEXT_CLASSIFICATION_MODEL_RAW)
        return _precheck_outputs(interactive=False, preview=df)

    if not isinstance(ds_labels, list) or not isinstance(ds_features, list):
        gr.Warning(CHECK_CONFIG_OR_SPLIT_RAW)
        return _precheck_outputs(interactive=False, preview=df)

    return _precheck_outputs(interactive=True, preview=df)
|
|
|
|
|
def _hidden_prediction_outputs(message="", banner=None):
    """Build the outputs used when no example prediction can be shown.

    The result has the same arity as the success path: six leading values
    followed by ``MAX_LABELS + MAX_FEATURES`` hidden dropdowns. ``banner``,
    when given, is shown in the first output; ``message`` is the sixth value.
    """
    if banner is None:
        banner_update = gr.update(visible=False)
    else:
        banner_update = gr.update(value=banner, visible=True)
    return (
        banner_update,
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False, open=False),
        gr.update(interactive=False),
        message,
        *[gr.Dropdown(visible=False) for _ in range(MAX_LABELS + MAX_FEATURES)],
    )


def align_columns_and_show_prediction(
    model_id,
    dataset_id,
    dataset_config,
    dataset_split,
    uid,
    profile: gr.OAuthProfile | None,
    oauth_token: gr.OAuthToken | None,
):
    """Run an example prediction and prepare the label/feature mapping UI.

    Args:
        model_id: Model id or URL; normalised with ``strip_model_id_from_url``.
        dataset_id: Dataset identifier.
        dataset_config: Dataset configuration name.
        dataset_split: Dataset split name.
        uid: Identifier under which the column mapping is stored.
        profile: OAuth profile, ``None`` when absent.
        oauth_token: OAuth token, ``None`` when absent.

    Returns:
        A tuple of ``6 + MAX_LABELS + MAX_FEATURES`` values on every path:
        status banner update, example input update, example prediction update,
        an update carrying ``visible``/``open`` flags, an update carrying an
        ``interactive`` flag, a message string, then the mapping dropdowns.
        The ``interactive`` flag is true only when both ``profile`` and
        ``oauth_token`` are present and a prediction could be shown.
    """
    model_id = strip_model_id_from_url(model_id)
    model_task = check_model_task(model_id)
    if model_task != "text-classification":
        gr.Warning(NOT_TEXT_CLASSIFICATION_MODEL_RAW)
        # This path used to return one value fewer than all the others.
        return _hidden_prediction_outputs()

    hf_token = os.environ.get(HF_WRITE_TOKEN, default="")

    prediction_input, prediction_response = get_example_prediction(
        model_id, dataset_id, dataset_config, dataset_split, hf_token
    )

    if prediction_input is None or prediction_response is None:
        return _hidden_prediction_outputs()

    if isinstance(prediction_response, HuggingFaceInferenceAPIResponse):
        return _hidden_prediction_outputs(
            message=f"Hugging Face Inference API is loading your model. {prediction_response.message}"
        )

    model_labels = list(prediction_response.keys())

    ds = datasets.load_dataset(dataset_id, dataset_config, split=dataset_split, trust_remote_code=True)
    ds_labels, ds_features, _ = get_labels_and_features_from_dataset(ds)

    if not isinstance(ds_labels, list) or not isinstance(ds_features, list):
        gr.Warning(CHECK_CONFIG_OR_SPLIT_RAW)
        return _hidden_prediction_outputs()

    if len(ds_labels) != len(model_labels):
        return _hidden_prediction_outputs(banner=UNMATCHED_MODEL_DATASET_STYLED_ERROR)

    column_mappings = list_labels_and_features_from_dataset(
        ds_labels,
        ds_features,
        model_labels,
        uid,
    )

    # The mapping panel is opened only when the labels differ between model
    # and dataset, or the first dataset feature is not named "text".
    needs_manual_mapping = (
        collections.Counter(model_labels) != collections.Counter(ds_labels)
        or ds_features[0] != "text"
    )
    if needs_manual_mapping:
        banner = MAPPING_STYLED_ERROR_WARNING
    else:
        banner = VALIDATED_MODEL_DATASET_STYLED

    return (
        gr.update(value=banner, visible=True),
        gr.update(value=prediction_input, lines=min(len(prediction_input)//225 + 1, 5), visible=True),
        gr.update(value=prediction_response, visible=True),
        gr.update(visible=True, open=needs_manual_mapping),
        gr.update(interactive=(profile is not None and oauth_token is not None)),
        "",
        *column_mappings,
    )
|
|
|
|
|
def check_column_mapping_keys_validity(all_mappings):
    """Tell whether a column mapping exists and contains a ``"labels"`` entry.

    A ``None`` mapping is logged and reported with a UI warning; a mapping
    without ``"labels"`` is only logged. Both cases return ``False``.
    """
    if all_mappings is not None:
        if "labels" in all_mappings.keys():
            return True
        logger.warning(f"Label mapping is not valid, all_mappings: {all_mappings}")
        return False

    logger.warning("all_mapping is None")
    gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
    return False
|
|
|
def enable_run_btn(
    uid,
    model_id,
    dataset_id,
    dataset_config,
    dataset_split,
    profile: gr.OAuthProfile | None,
    oath_token: gr.OAuthToken | None,
):
    """Return a Gradio update that enables or disables the run button.

    The update is interactive only when an OAuth profile and token are
    present, the model, dataset, config and split are all selected, and the
    column mapping stored for ``uid`` passes
    ``check_column_mapping_keys_validity``.
    """
    if profile is None or oath_token is None:
        return gr.update(interactive=False)

    # Truthiness also rejects None, not just the empty string.
    if not (model_id and dataset_id and dataset_config and dataset_split):
        logger.warning("Model id or dataset id is not selected")
        return gr.update(interactive=False)

    all_mappings = read_column_mapping(uid)
    if not check_column_mapping_keys_validity(all_mappings):
        logger.warning("Column mapping is not valid")
        return gr.update(interactive=False)

    # Every check passed: previously the function fell through and returned
    # None here, so the button was never enabled.
    return gr.update(interactive=True)
|
|
|
def construct_label_and_feature_mapping(all_mappings, ds_labels, ds_features, label_keys=None):
    """Derive the label and feature mappings for a job from the column mapping.

    Args:
        all_mappings: Column mapping with a ``"labels"`` section and, normally,
            a ``"features"`` section. It is not modified.
        ds_labels: Dataset label names, in dataset order.
        ds_features: Dataset feature names; used only for a size check.
        label_keys: Optional sequence whose first entry is stored under
            ``"label"`` in the feature mapping.

    Returns:
        ``(label_mapping, feature_mapping)``. ``label_mapping`` maps the
        stringified position of each dataset label to the value stored for it
        in ``all_mappings["labels"]``. ``feature_mapping`` is a copy of the
        ``"features"`` section (empty when missing) plus the optional
        ``"label"`` entry.

    Raises:
        KeyError: If ``"labels"`` is missing or has no entry for a dataset
            label.
    """
    stored_labels = all_mappings["labels"]
    if len(stored_labels.keys()) != len(ds_labels):
        logger.warning(f"""Label mapping corrupted: {CONFIRM_MAPPING_DETAILS_FAIL_RAW}.
            \nall_mappings: {all_mappings}\nds_labels: {ds_labels}""")

    # Check for the section before reading it, so a missing section produces
    # the intended warning instead of a KeyError.
    if "features" not in all_mappings.keys():
        logger.warning("features not in all_mappings")
        gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
        feature_mapping = {}
    else:
        # Copy so the "label" entry added below does not leak into all_mappings.
        feature_mapping = dict(all_mappings["features"])
        if len(feature_mapping.keys()) != len(ds_features):
            logger.warning(f"""Feature mapping corrupted: {CONFIRM_MAPPING_DETAILS_FAIL_RAW}.
            \nall_mappings: {all_mappings}\nds_features: {ds_features}""")

    label_mapping = {str(i): stored_labels[label] for i, label in enumerate(ds_labels)}

    # label_keys defaults to None, so test truthiness rather than len().
    if label_keys:
        feature_mapping["label"] = label_keys[0]
    return label_mapping, feature_mapping
|
|
|
def show_hf_token_info(token):
    """Return a visibility update driven by ``check_hf_token_validity``.

    The update makes its component visible when the token is reported as
    invalid and hides it otherwise.
    """
    token_is_invalid = not check_hf_token_validity(token)
    return gr.update(visible=token_is_invalid)
|
|
|
def _rejected_submission_outputs(uid):
    """Build the seven outputs returned when a submission is not accepted.

    The first update is interactive, the second hides its component, the
    session ``uid`` is kept, and the last four components are left unchanged
    (``gr.update()`` with no arguments).
    """
    return (
        gr.update(interactive=True),
        gr.update(visible=False),
        uid,
        gr.update(),
        gr.update(),
        gr.update(),
        gr.update(),
    )


def try_submit(
    m_id,
    d_id,
    config,
    split,
    uid,
    profile: gr.OAuthProfile | None,
    oath_token: gr.OAuthToken | None,
):
    """Queue an evaluation job for the selected model and dataset.

    Args:
        m_id: Model identifier.
        d_id: Dataset identifier.
        config: Dataset configuration name.
        split: Dataset split name.
        uid: Current session/job identifier; used to read the column mapping
            and scanners and reported as the job id.
        profile: OAuth profile (not read here).
        oath_token: OAuth token whose ``token`` attribute is sent with the job.

    Returns:
        A 7-tuple on every path. On success: an update with
        ``interactive=False``, an update showing the job id message, a freshly
        generated uid, and four updates hiding their components. When the
        column mapping is invalid or no token is available, the values built
        by ``_rejected_submission_outputs``.
    """
    all_mappings = read_column_mapping(uid)
    if not check_column_mapping_keys_validity(all_mappings):
        # This path used to return only two values while the success path
        # returns seven.
        return _rejected_submission_outputs(uid)

    if oath_token is None:
        # The job needs oath_token.token; bail out instead of raising
        # AttributeError on None.
        logger.warning("Cannot submit the evaluation: no OAuth token available")
        return _rejected_submission_outputs(uid)

    ds = datasets.load_dataset(d_id, config, split=split, trust_remote_code=True)
    ds_labels, ds_features, label_keys = get_labels_and_features_from_dataset(ds)
    label_mapping, feature_mapping = construct_label_and_feature_mapping(
        all_mappings, ds_labels, ds_features, label_keys
    )

    eval_str = f"[{m_id}]<{d_id}({config}, {split} set)>"
    save_job_to_pipe(
        uid,
        (
            m_id,
            d_id,
            config,
            split,
            oath_token.token,
            uid,
            label_mapping,
            feature_mapping,
        ),
        eval_str,
        # NOTE(review): a fresh lock per call cannot exclude concurrent
        # callers; confirm how save_job_to_pipe uses it before sharing one.
        threading.Lock(),
    )
    gr.Info("Your evaluation has been submitted")

    # Copy the scanners stored for this uid to a newly generated uid, which
    # is returned as the third output.
    new_uid = uuid.uuid4()
    scanners = read_scanners(uid)
    write_scanners(scanners, new_uid)

    return (
        gr.update(interactive=False),
        gr.update(value=f"{CHECK_LOG_SECTION_RAW}Your job id is: {uid}. ", lines=5, visible=True, interactive=False),
        new_uid,
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False),
    )
|
|