import collections
import logging
import os
import threading
import uuid

import datasets
import gradio as gr
import pandas as pd

import leaderboard
from io_utils import read_column_mapping, write_column_mapping
from run_jobs import save_job_to_pipe
from text_classification import (
    strip_model_id_from_url,
    check_model_task,
    preload_hf_inference_api,
    get_example_prediction,
    get_labels_and_features_from_dataset,
    check_hf_token_validity,
    HuggingFaceInferenceAPIResponse,
)
from wordings import (
    CHECK_CONFIG_OR_SPLIT_RAW,
    CONFIRM_MAPPING_DETAILS_FAIL_RAW,
    MAPPING_STYLED_ERROR_WARNING,
    NOT_TEXT_CLASSIFICATION_MODEL_RAW,
    UNMATCHED_MODEL_DATASET_STYLED_ERROR,
    CHECK_LOG_SECTION_RAW,
    VALIDATED_MODEL_DATASET_STYLED,
    get_dataset_fetch_error_raw,
)

# Number of label / feature dropdown slots reserved in the UI. Unused slots are
# returned hidden so the handlers always emit the same number of outputs.
MAX_LABELS = 40
MAX_FEATURES = 20

# Module-level placeholders; not referenced by any function shown here.
ds_dict = None
ds_config = None

logger = logging.getLogger(__file__)

# One lock shared by every job submission. Creating a new threading.Lock()
# per call (as before) can never be contended, so it gave no mutual exclusion
# between concurrent submissions.
_job_pipe_lock = threading.Lock()


def get_related_datasets_from_leaderboard(model_id):
    """Return a dropdown update whose choices are the leaderboard datasets for a model.

    Args:
        model_id: model id or URL; normalized with ``strip_model_id_from_url``.

    Returns:
        ``gr.update(choices=[...])`` with the unique ``dataset_id`` values of the
        leaderboard records matching the model (an empty list if none).
    """
    records = leaderboard.records
    model_id = strip_model_id_from_url(model_id)
    model_records = records[records["model_id"] == model_id]
    datasets_unique = list(model_records["dataset_id"].unique())
    return gr.update(choices=datasets_unique)


def get_dataset_splits(dataset_id, dataset_config):
    """Return a dropdown update listing the splits of ``dataset_id``/``dataset_config``.

    The first split is pre-selected. Any failure (including a dataset with no
    splits, where ``splits[0]`` raises) is logged and the dropdown is hidden.
    """
    try:
        splits = datasets.get_dataset_split_names(
            dataset_id, dataset_config, trust_remote_code=True
        )
        return gr.update(choices=splits, value=splits[0], visible=True)
    except Exception as e:
        logger.warning(f"Check your dataset {dataset_id} and config {dataset_config}: {e}")
        return gr.update(visible=False)


def check_dataset(dataset_id):
    """Fetch configs and splits of ``dataset_id`` for the config/split dropdowns.

    Returns:
        A 3-tuple ``(config_update, split_update, "")``. On success both
        dropdowns are shown with their first entry selected; if the dataset has
        no configs or fetching fails, both are hidden.
    """
    logger.info(f"Loading {dataset_id}")
    hidden = (gr.update(visible=False), gr.update(visible=False), "")
    try:
        configs = datasets.get_dataset_config_names(dataset_id, trust_remote_code=True)
        if not configs:
            return hidden
        splits = datasets.get_dataset_split_names(
            dataset_id, configs[0], trust_remote_code=True
        )
        return (
            gr.update(choices=configs, value=configs[0], visible=True),
            gr.update(choices=splits, value=splits[0], visible=True),
            "",
        )
    except Exception as e:
        logger.warning(f"Check your dataset {dataset_id}: {e}")
        error_text = str(e)
        # GSK-2770: "forbidden" errors are surfaced to the user as well.
        if "doesn't exist" in error_text or "forbidden" in error_text.lower():
            gr.Warning(get_dataset_fetch_error_raw(e))
        return hidden


def empty_column_mapping(uid):
    """Reset the stored column mapping for ``uid`` to ``None``."""
    write_column_mapping(None, uid)


def write_column_mapping_to_config(uid, *labels):
    """Persist the current dropdown values as the column mapping of ``uid``.

    Args:
        uid: session/job identifier used as the storage key.
        *labels: values of all dropdowns, in UI order. The first ``MAX_LABELS``
            are label dropdowns; the next ``MAX_FEATURES`` are feature dropdowns.
    """
    # TODO: Substitute 'text' with more features for zero-shot
    # we are not using ds features because we only support "text" for now
    if not labels:
        return
    # The stored mapping may be None (e.g. after empty_column_mapping).
    all_mappings = read_column_mapping(uid) or {}
    all_mappings = export_mappings(all_mappings, "labels", None, labels[:MAX_LABELS])
    all_mappings = export_mappings(
        all_mappings,
        "features",
        ["text"],
        labels[MAX_LABELS : (MAX_LABELS + MAX_FEATURES)],
    )
    write_column_mapping(all_mappings, uid)


def export_mappings(all_mappings, key, subkeys, values):
    """Write ``subkeys[i] -> values[i % len(values)]`` into ``all_mappings[key]``.

    Args:
        all_mappings: mapping dict, mutated in place (``key`` is created if missing).
        key: top-level section, e.g. ``"labels"`` or ``"features"``.
        subkeys: keys to assign; ``None`` means "the keys already present under
            ``key``", in their existing order. Falsy subkeys are skipped.
        values: values assigned positionally, cycling if shorter than ``subkeys``.

    Returns:
        ``all_mappings``. It is left unchanged apart from creating ``key`` when
        there are no subkeys or no values; the empty-values check avoids a
        modulo-by-zero.
    """
    if key not in all_mappings:
        all_mappings[key] = dict()
    if subkeys is None:
        subkeys = list(all_mappings[key].keys())
    if not subkeys:
        logger.debug(f"subkeys is empty for {key}")
        return all_mappings
    if not values:
        logger.debug(f"values is empty for {key}")
        return all_mappings
    for i, subkey in enumerate(subkeys):
        if subkey:
            all_mappings[key][subkey] = values[i % len(values)]
    return all_mappings


def list_labels_and_features_from_dataset(ds_labels, ds_features, model_labels, uid):
    """Build the label/feature mapping dropdowns and store the default mapping.

    Each dataset label gets a dropdown whose choices are the model labels. The
    default is paired positionally after sorting both sides. The single "text"
    feature gets a dropdown over the dataset features. The default mapping is
    written for ``uid``. The input lists are not modified.

    Returns:
        ``MAX_LABELS`` label dropdowns followed by ``MAX_FEATURES`` feature
        dropdowns; unused slots are hidden.
    """
    all_mappings = read_column_mapping(uid) or {}

    # For flattened raw datasets with no labels
    # check if there are shared labels between model and dataset
    shared_labels = set(model_labels).intersection(ds_labels)
    if shared_labels:
        ds_labels = list(shared_labels)

    # sort labels to make sure the order is consistent
    # (prediction gives the order based on probability). Sorting before
    # truncating keeps the retained labels deterministic.
    ds_labels = sorted(ds_labels)
    model_labels = sorted(model_labels)
    if len(ds_labels) > MAX_LABELS:
        ds_labels = ds_labels[:MAX_LABELS]
        gr.Warning(f"The number of labels is truncated to length {MAX_LABELS}")

    label_dropdowns = [
        gr.Dropdown(
            label=f"{label}",
            choices=model_labels,
            value=model_labels[i % len(model_labels)],
            interactive=True,
            visible=True,
        )
        for i, label in enumerate(ds_labels)
    ]
    label_dropdowns += [
        gr.Dropdown(visible=False) for _ in range(MAX_LABELS - len(label_dropdowns))
    ]
    all_mappings = export_mappings(all_mappings, "labels", ds_labels, model_labels)

    # TODO: Substitute 'text' with more features for zero-shot
    feature_dropdowns = [
        gr.Dropdown(
            label=f"{feature}",
            choices=ds_features,
            value=ds_features[0],
            interactive=True,
            visible=True,
        )
        for feature in ["text"]
    ]
    feature_dropdowns += [
        gr.Dropdown(visible=False) for _ in range(MAX_FEATURES - len(feature_dropdowns))
    ]
    all_mappings = export_mappings(all_mappings, "features", ["text"], ds_features)

    write_column_mapping(all_mappings, uid)
    return label_dropdowns + feature_dropdowns


def _precheck_outputs(example_btn_interactive, preview_df=None):
    """Build the 6 outputs of ``precheck_model_ds_enable_example_btn``.

    The first output toggles interactivity (presumably of the example button);
    the second shows ``preview_df`` or is hidden when it is ``None``; the
    remaining four are always hidden.
    """
    if preview_df is not None:
        preview = gr.update(value=preview_df, visible=True)
    else:
        preview = gr.update(visible=False)
    return (
        gr.update(interactive=example_btn_interactive),
        preview,
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False),
    )


def precheck_model_ds_enable_example_btn(
    model_id, dataset_id, dataset_config, dataset_split
):
    """Validate the model/dataset selection and preview the first 5 dataset rows.

    Also triggers ``preload_hf_inference_api`` for the model. The first output
    is interactive only when the model task is "text-classification" and the
    split yields list-typed labels and features. Any loading error is logged
    and all previews are hidden.
    """
    model_id = strip_model_id_from_url(model_id)
    model_task = check_model_task(model_id)
    preload_hf_inference_api(model_id)

    if dataset_config is None or dataset_split is None or len(dataset_config) == 0:
        return _precheck_outputs(False)

    try:
        ds = datasets.load_dataset(dataset_id, dataset_config, trust_remote_code=True)
        df: pd.DataFrame = ds[dataset_split].to_pandas().head(5)
        ds_labels, ds_features = get_labels_and_features_from_dataset(ds[dataset_split])

        if model_task != "text-classification":
            gr.Warning(NOT_TEXT_CLASSIFICATION_MODEL_RAW)
            return _precheck_outputs(False, df)

        if not isinstance(ds_labels, list) or not isinstance(ds_features, list):
            gr.Warning(CHECK_CONFIG_OR_SPLIT_RAW)
            return _precheck_outputs(False, df)

        return _precheck_outputs(True, df)
    except Exception as e:
        # Config or split wrong
        logger.warning(
            f"Check your dataset {dataset_id} and config {dataset_config} on split {dataset_split}: {e}"
        )
        return _precheck_outputs(False)


def _hidden_alignment_outputs(status=None, message=""):
    """Build the "nothing to show" outputs of ``align_columns_and_show_prediction``.

    The layout is 5 component updates, then a status string, then
    ``MAX_LABELS + MAX_FEATURES`` hidden dropdowns. This matches the success
    path, so every return of that handler has the same arity.

    Args:
        status: update for the first (status) component; hidden if ``None``.
        message: text for the status-string output.
    """
    return (
        status if status is not None else gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False, open=False),
        gr.update(interactive=False),
        message,
        *[gr.Dropdown(visible=False) for _ in range(MAX_LABELS + MAX_FEATURES)],
    )


def align_columns_and_show_prediction(
    model_id,
    dataset_id,
    dataset_config,
    dataset_split,
    uid,
    run_inference,
    inference_token,
):
    """Run an example prediction and align model labels with dataset labels.

    Output layout, identical on every path:

    1. Status message.
    2. Example input.
    3. Example prediction.
    4. A component with an ``open`` state (presumably the mapping accordion).
    5. Run-control interactivity, enabled only when ``run_inference`` is set
       and ``inference_token`` is non-empty.
    6. A status string.
    7. ``MAX_LABELS + MAX_FEATURES`` mapping dropdowns.

    When the labels differ as multisets, or the first feature is not "text",
    the mapping dropdowns are shown for manual editing.
    """
    model_id = strip_model_id_from_url(model_id)
    model_task = check_model_task(model_id)
    if model_task != "text-classification":
        gr.Warning(NOT_TEXT_CLASSIFICATION_MODEL_RAW)
        return _hidden_alignment_outputs()

    hf_token = os.environ.get("HF_WRITE_TOKEN", default="")

    prediction_input, prediction_response = get_example_prediction(
        model_id, dataset_id, dataset_config, dataset_split, hf_token
    )

    if prediction_input is None or prediction_response is None:
        return _hidden_alignment_outputs()

    if isinstance(prediction_response, HuggingFaceInferenceAPIResponse):
        return _hidden_alignment_outputs(
            message=f"Hugging Face Inference API is loading your model. {prediction_response.message}"
        )

    model_labels = list(prediction_response.keys())

    ds = datasets.load_dataset(
        dataset_id, dataset_config, split=dataset_split, trust_remote_code=True
    )
    ds_labels, ds_features = get_labels_and_features_from_dataset(ds)

    # when dataset does not have labels or features (an empty feature list
    # would otherwise fail on ds_features[0] below)
    if (
        not isinstance(ds_labels, list)
        or not isinstance(ds_features, list)
        or not ds_features
    ):
        gr.Warning(CHECK_CONFIG_OR_SPLIT_RAW)
        return _hidden_alignment_outputs()

    if len(ds_labels) != len(model_labels):
        return _hidden_alignment_outputs(
            status=gr.update(value=UNMATCHED_MODEL_DATASET_STYLED_ERROR, visible=True)
        )

    column_mappings = list_labels_and_features_from_dataset(
        ds_labels,
        ds_features,
        model_labels,
        uid,
    )
    run_btn_update = gr.update(interactive=(run_inference and inference_token != ""))

    # when labels or features are not aligned
    # show manually column mapping
    if (
        collections.Counter(model_labels) != collections.Counter(ds_labels)
        or ds_features[0] != "text"
    ):
        return (
            gr.update(value=MAPPING_STYLED_ERROR_WARNING, visible=True),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=True, open=True),
            run_btn_update,
            "",
            *column_mappings,
        )

    return (
        gr.update(value=VALIDATED_MODEL_DATASET_STYLED, visible=True),
        gr.update(
            value=prediction_input,
            lines=len(prediction_input) // 225 + 1,
            visible=True,
        ),
        gr.update(value=prediction_response, visible=True),
        gr.update(visible=True, open=False),
        run_btn_update,
        "",
        *column_mappings,
    )


def check_column_mapping_keys_validity(all_mappings):
    """Return True if the mapping exists and has both "labels" and "features".

    Shows ``CONFIRM_MAPPING_DETAILS_FAIL_RAW`` as a Gradio warning otherwise.
    Both sections are read by ``construct_label_and_feature_mapping`` at
    submit time.
    """
    if (
        all_mappings is None
        or "labels" not in all_mappings
        or "features" not in all_mappings
    ):
        gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
        return False
    return True


def enable_run_btn(uid, run_inference, inference_token, model_id, dataset_id, dataset_config, dataset_split):
    """Return an interactivity update for the run button.

    The button is enabled only when inference is on with a non-empty token,
    all selectors are non-empty, the stored mapping is valid, and the HF token
    passes ``check_hf_token_validity``.
    """
    if not run_inference or inference_token == "":
        logger.warning("Inference API is not enabled")
        return gr.update(interactive=False)
    if model_id == "" or dataset_id == "" or dataset_config == "" or dataset_split == "":
        logger.warning("Model id or dataset id is not selected")
        return gr.update(interactive=False)

    all_mappings = read_column_mapping(uid)
    if not check_column_mapping_keys_validity(all_mappings):
        logger.warning("Column mapping is not valid")
        return gr.update(interactive=False)

    if not check_hf_token_validity(inference_token):
        logger.warning("HF token is not valid")
        return gr.update(interactive=False)
    return gr.update(interactive=True)


def construct_label_and_feature_mapping(all_mappings, ds_labels, ds_features):
    """Derive the job's label and feature mappings from the stored mapping.

    Returns:
        ``(label_mapping, feature_mapping)``. ``label_mapping`` maps the string
        index of each dataset label, in ``ds_labels`` order, to its mapped model
        label. ``feature_mapping`` is the stored "features" section, or ``{}``
        if it is missing (after a user warning).

    Raises:
        KeyError: if a label in ``ds_labels`` is missing from the stored
            "labels" section.
    """
    # Check presence before any access to the "features" section.
    if "features" not in all_mappings:
        gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
    feature_mapping = all_mappings.get("features", {})

    if len(all_mappings["labels"]) != len(ds_labels):
        logger.warning("Label mapping corrupted: " + CONFIRM_MAPPING_DETAILS_FAIL_RAW)
    if len(feature_mapping) != len(ds_features):
        logger.warning("Feature mapping corrupted: " + CONFIRM_MAPPING_DETAILS_FAIL_RAW)

    # align the saved labels with dataset labels order
    label_mapping = {
        str(i): all_mappings["labels"][label] for i, label in enumerate(ds_labels)
    }
    return label_mapping, feature_mapping


def show_hf_token_info(token):
    """Show the token-info component only when ``token`` fails validation."""
    return gr.update(visible=not check_hf_token_validity(token))


def try_submit(m_id, d_id, config, split, inference, inference_token, uid):
    """Queue an evaluation job for the current selection and mapping.

    Returns:
        A 3-tuple on every path:

        * Submit-button update.
        * Log-box update.
        * Uid output. On success this is a fresh ``uuid.uuid4()``; on an invalid
          mapping it is the unchanged ``uid``, so the arity matches the success
          path.
    """
    all_mappings = read_column_mapping(uid)
    if not check_column_mapping_keys_validity(all_mappings):
        return (gr.update(interactive=True), gr.update(visible=False), uid)

    # get ds labels and features again for alignment
    ds = datasets.load_dataset(d_id, config, split=split, trust_remote_code=True)
    ds_labels, ds_features = get_labels_and_features_from_dataset(ds)
    label_mapping, feature_mapping = construct_label_and_feature_mapping(
        all_mappings, ds_labels, ds_features
    )

    eval_str = f"[{m_id}]<{d_id}({config}, {split} set)>"
    save_job_to_pipe(
        uid,
        (
            m_id,
            d_id,
            config,
            split,
            inference,
            inference_token,
            uid,
            label_mapping,
            feature_mapping,
        ),
        eval_str,
        _job_pipe_lock,
    )
    gr.Info("Your evaluation has been submitted")

    return (
        gr.update(interactive=False),  # Submit button
        gr.update(
            value=f"{CHECK_LOG_SECTION_RAW}Your job id is: {uid}. ",
            lines=5,
            visible=True,
            interactive=False,
        ),
        uuid.uuid4(),  # Allocate a new uuid
    )