import collections
import json
import logging
import os
import threading
import uuid

import datasets
import gradio as gr
import pandas as pd
from transformers.pipelines import TextClassificationPipeline

from io_utils import (
    get_yaml_path,
    read_column_mapping,
    save_job_to_pipe,
    write_column_mapping,
    write_log_to_user_file,
)
from text_classification import (
    check_model,
    get_example_prediction,
    get_labels_and_features_from_dataset,
)
from wordings import (
    CHECK_CONFIG_OR_SPLIT_RAW,
    CONFIRM_MAPPING_DETAILS_FAIL_RAW,
    MAPPING_STYLED_ERROR_WARNING,
    get_styled_input,
)

MAX_LABELS = 40
MAX_FEATURES = 20

HF_REPO_ID = "HF_REPO_ID"
HF_SPACE_ID = "SPACE_ID"
HF_WRITE_TOKEN = "HF_WRITE_TOKEN"
HF_GSK_HUB_URL = "GSK_HUB_URL"
HF_GSK_HUB_PROJECT_KEY = "GSK_HUB_PROJECT_KEY"
HF_GSK_HUB_KEY = "GSK_API_KEY"
HF_GSK_HUB_HF_TOKEN = "GSK_HF_TOKEN"
HF_GSK_HUB_UNLOCK_TOKEN = "GSK_HUB_UNLOCK_TOKEN"

LEADERBOARD = "giskard-bot/evaluator-leaderboard"

# Optional Giskard Hub settings forwarded to the scanner command, as
# (CLI flag, environment variable name) pairs, in the order they are appended.
_GSK_HUB_CLI_OPTIONS = (
    ("--giskard_hub_api_key", HF_GSK_HUB_KEY),
    ("--giskard_hub_url", HF_GSK_HUB_URL),
    ("--giskard_hub_project_key", HF_GSK_HUB_PROJECT_KEY),
    ("--giskard_hub_hf_token", HF_GSK_HUB_HF_TOKEN),
    ("--giskard_hub_unlock_token", HF_GSK_HUB_UNLOCK_TOKEN),
)

# One lock shared by every submission. A lock created per call cannot exclude
# anything. NOTE(review): assumes save_job_to_pipe uses this lock only to guard
# its own write to the shared job pipe; confirm in io_utils.
_JOB_PIPE_LOCK = threading.Lock()

logger = logging.getLogger(__file__)


def check_dataset(dataset_id, dataset_config=None, dataset_split=None):
    """Resolve the configs/splits of a Hugging Face dataset and preview it.

    When ``dataset_config`` is empty, the first available config (and its
    first split) is selected. When only ``dataset_split`` is missing, the
    first split of the given config is selected.

    Returns:
        A 4-tuple of Gradio outputs:
        - the config dropdown;
        - the split dropdown;
        - a DataFrame with the first 5 rows of the selected split (hidden when
          loading failed before a config/split could be chosen);
        - an empty string.
    """
    configs = ["default"]
    splits = ["default"]
    dataset_dict = None
    logger.info(f"Loading {dataset_id}, {dataset_config}, {dataset_split}")
    try:
        configs = datasets.get_dataset_config_names(dataset_id)
        if not dataset_config:
            # Treat an empty config like a missing one: it is the config we
            # actually load below, so keep the selection consistent with it.
            dataset_config = configs[0]
            dataset_split = None
        dataset_dict = datasets.load_dataset(dataset_id, dataset_config)
        splits = list(dataset_dict.keys())
        if dataset_split is None:
            dataset_split = splits[0]
    except Exception as e:  # Dataset may not exist
        logger.warning(
            f"Failed to load dataset {dataset_id} with config {dataset_config}: {e}"
        )
        if not dataset_config:
            return (
                gr.Dropdown(configs, value=configs[0], visible=True),
                gr.Dropdown(splits, value=splits[0], visible=True),
                gr.DataFrame(pd.DataFrame(), visible=False),
                "",
            )
        if dataset_split is None:
            return (
                gr.Dropdown(configs, value=dataset_config, visible=True),
                gr.Dropdown(splits, value=splits[0], visible=True),
                gr.DataFrame(pd.DataFrame(), visible=False),
                "",
            )

    if dataset_dict is None:
        # Loading failed above but config and split were both given: retry
        # the load (any error now propagates to the caller, as before).
        dataset_dict = datasets.load_dataset(dataset_id, dataset_config)
    dataframe: pd.DataFrame = dataset_dict[dataset_split].to_pandas().head(5)
    return (
        gr.Dropdown(configs, value=dataset_config, visible=True),
        gr.Dropdown(splits, value=dataset_split, visible=True),
        gr.DataFrame(dataframe, visible=True),
        "",
    )


def select_run_mode(run_inf):
    """Toggle the UI when the "run with Inference API" option changes.

    Returns a visibility update for the first output and the opposite boolean
    value for the second output.
    """
    if run_inf:
        return (gr.update(visible=True), gr.update(value=False))
    return (gr.update(visible=False), gr.update(value=True))


def deselect_run_inference(run_local):
    """Toggle the UI when the "run locally" option changes.

    Returns a visibility update for the first output (hidden when running
    locally) and the opposite boolean value for the second output.
    """
    if run_local:
        return (gr.update(visible=False), gr.update(value=False))
    return (gr.update(visible=True), gr.update(value=True))


def write_column_mapping_to_config(uid, *labels):
    """Persist the current mapping dropdown values for session ``uid``.

    ``labels`` holds the values of all mapping dropdowns in UI order. The first
    ``MAX_LABELS`` entries are label choices; the next ``MAX_FEATURES`` are
    feature choices.
    """
    # TODO: Substitute 'text' with more features for zero-shot
    # we are not using ds features because we only support "text" for now
    if not labels:
        # ``*labels`` is always a tuple (never None), so test for emptiness.
        return
    all_mappings = read_column_mapping(uid) or {}
    all_mappings = export_mappings(all_mappings, "labels", None, labels[:MAX_LABELS])
    all_mappings = export_mappings(
        all_mappings,
        "features",
        ["text"],
        labels[MAX_LABELS : (MAX_LABELS + MAX_FEATURES)],
    )
    write_column_mapping(all_mappings, uid)


def export_mappings(all_mappings, key, subkeys, values):
    """Write ``values`` into ``all_mappings[key]`` under ``subkeys``.

    Args:
        all_mappings: Mapping dict, updated in place (and also returned).
        key: Top-level section, e.g. "labels" or "features".
        subkeys: Names to assign. ``None`` means reuse the keys already present
            in ``all_mappings[key]``. Falsy subkeys are skipped.
        values: Values assigned positionally; cycled if shorter than
            ``subkeys``.

    Returns:
        ``all_mappings``.
    """
    if key not in all_mappings:
        all_mappings[key] = dict()
    if subkeys is None:
        subkeys = list(all_mappings[key].keys())
    if not subkeys:
        logger.debug(f"subkeys is empty for {key}")
        return all_mappings
    if not values:
        # Nothing to assign; cycling over an empty sequence would divide by 0.
        logger.debug(f"values is empty for {key}")
        return all_mappings
    for i, subkey in enumerate(subkeys):
        if subkey:
            all_mappings[key][subkey] = values[i % len(values)]
    return all_mappings


def list_labels_and_features_from_dataset(ds_labels, ds_features, model_id2label, uid):
    """Build the label/feature mapping dropdowns and persist a default mapping.

    Dataset labels that the model also has are preferred. If there is no
    overlap, all dataset labels are used, truncated to ``MAX_LABELS``. Sorted
    dataset labels are paired positionally with the sorted model labels. The
    same pairing is used for both the dropdown defaults and the mapping
    written for ``uid``, so what the user sees is what gets saved.

    Returns:
        ``MAX_LABELS`` label dropdowns followed by ``MAX_FEATURES`` feature
        dropdowns; the unused ones are hidden.
    """
    model_labels = sorted(model_id2label.values())
    all_mappings = read_column_mapping(uid) or {}

    # For flattened raw datasets with no labels
    # check if there are shared labels between model and dataset
    shared_labels = set(model_labels).intersection(ds_labels)
    if shared_labels:
        ds_labels = list(shared_labels)
    if len(ds_labels) > MAX_LABELS:
        ds_labels = ds_labels[:MAX_LABELS]
        gr.Warning(f"The number of labels is truncated to length {MAX_LABELS}")
    # Sort a copy so the caller's list is not mutated.
    ds_labels = sorted(ds_labels)

    label_dropdowns = [
        gr.Dropdown(
            label=f"{label}",
            choices=model_labels,
            value=model_labels[i % len(model_labels)] if model_labels else None,
            interactive=True,
            visible=True,
        )
        for i, label in enumerate(ds_labels)
    ]
    label_dropdowns += [
        gr.Dropdown(visible=False) for _ in range(MAX_LABELS - len(label_dropdowns))
    ]
    all_mappings = export_mappings(all_mappings, "labels", ds_labels, model_labels)

    # TODO: Substitute 'text' with more features for zero-shot
    feature_dropdowns = [
        gr.Dropdown(
            label=f"{feature}",
            choices=ds_features,
            value=ds_features[0] if ds_features else None,
            interactive=True,
            visible=True,
        )
        for feature in ["text"]
    ]
    feature_dropdowns += [
        gr.Dropdown(visible=False)
        for _ in range(MAX_FEATURES - len(feature_dropdowns))
    ]
    all_mappings = export_mappings(all_mappings, "features", ["text"], ds_features)
    write_column_mapping(all_mappings, uid)

    return label_dropdowns + feature_dropdowns


def _check_text_classification_model(model_id):
    """Return the pipeline for ``model_id``, or warn and return None if unusable."""
    ppl = check_model(model_id)
    if ppl is None or not isinstance(ppl, TextClassificationPipeline):
        gr.Warning("Please check your model.")
        return None
    return ppl


def _get_dataset_labels_and_features(dataset_id, dataset_config, dataset_split):
    """Fetch dataset labels and features, or warn and return ``(None, None)``."""
    ds_labels, ds_features = get_labels_and_features_from_dataset(
        dataset_id, dataset_config, dataset_split
    )
    # Non-list results mean labels/features could not be read. An empty
    # feature list cannot be mapped to the "text" input either.
    if (
        not isinstance(ds_labels, list)
        or not isinstance(ds_features, list)
        or not ds_features
    ):
        gr.Warning(CHECK_CONFIG_OR_SPLIT_RAW)
        return None, None
    return ds_labels, ds_features


def precheck_model_ds_enable_example_btn(
    model_id, dataset_id, dataset_config, dataset_split
):
    """Enable the example button only if both model and dataset are usable.

    Returns an interactivity update for the button and an empty string.
    """
    if _check_text_classification_model(model_id) is None:
        return gr.update(interactive=False), ""
    ds_labels, _ = _get_dataset_labels_and_features(
        dataset_id, dataset_config, dataset_split
    )
    if ds_labels is None:
        return gr.update(interactive=False), ""
    return gr.update(interactive=True), ""


def _hidden_prediction_outputs():
    """Output tuple that hides the prediction widgets and all mapping dropdowns."""
    return (
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False, open=False),
        gr.update(interactive=False),
        "",
        *[gr.update(visible=False) for _ in range(MAX_LABELS + MAX_FEATURES)],
    )


def align_columns_and_show_prediction(
    model_id, dataset_id, dataset_config, dataset_split, uid
):
    """Align model/dataset labels and show an example prediction.

    If the model labels and dataset labels differ (as multisets), or the first
    dataset feature is not "text", a mapping warning is shown with the mapping
    panel opened. Otherwise an example prediction is displayed.

    Returns:
        Five UI updates followed by ``MAX_LABELS + MAX_FEATURES`` dropdowns.
        Every return path has this same arity.
    """
    ppl = _check_text_classification_model(model_id)
    if ppl is None:
        return _hidden_prediction_outputs()

    model_id2label = ppl.model.config.id2label
    ds_labels, ds_features = _get_dataset_labels_and_features(
        dataset_id, dataset_config, dataset_split
    )
    # when dataset does not have labels or features
    if ds_labels is None:
        return _hidden_prediction_outputs()

    column_mappings = list_labels_and_features_from_dataset(
        ds_labels,
        ds_features,
        model_id2label,
        uid,
    )

    # when labels or features are not aligned
    # show manually column mapping
    if (
        collections.Counter(model_id2label.values()) != collections.Counter(ds_labels)
        or ds_features[0] != "text"
    ):
        return (
            gr.update(value=MAPPING_STYLED_ERROR_WARNING, visible=True),
            gr.update(visible=False),
            gr.update(visible=True, open=True),
            gr.update(interactive=True),
            "",
            *column_mappings,
        )

    prediction_input, prediction_output = get_example_prediction(
        ppl, dataset_id, dataset_config, dataset_split
    )
    return (
        gr.update(value=get_styled_input(prediction_input), visible=True),
        gr.update(value=prediction_output, visible=True),
        gr.update(visible=True, open=False),
        gr.update(interactive=True),
        "",
        *column_mappings,
    )


def check_column_mapping_keys_validity(all_mappings):
    """Validate a stored column mapping before submitting a job.

    Returns:
        ``None`` when ``all_mappings`` has both "labels" and "features".
        Otherwise, shows a warning and returns updates that keep the submit
        button interactive and hide the second submit output.
    """
    if (
        all_mappings is None
        or "labels" not in all_mappings
        or "features" not in all_mappings
    ):
        gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
        return (gr.update(interactive=True), gr.update(visible=False))
    return None


def construct_label_and_feature_mapping(all_mappings):
    """Convert a stored mapping into the scanner's label/feature mappings.

    Returns:
        ``(label_mapping, feature_mapping)``. ``label_mapping`` maps the string
        index ("0", "1", ...) of each stored label key to that key.
        ``feature_mapping`` is ``all_mappings["features"]`` as-is. If
        "features" is missing, a warning is shown and a pair of Gradio updates
        is returned instead.
    """
    label_mapping = {
        str(i): label for i, label in enumerate(all_mappings["labels"].keys())
    }
    if "features" not in all_mappings.keys():
        gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
        return (gr.update(interactive=True), gr.update(visible=False))
    feature_mapping = all_mappings["features"]
    return label_mapping, feature_mapping


def try_submit(m_id, d_id, config, split, local, inference, inference_token, uid):
    """Validate the mapping, build the ``giskard_scanner`` command and queue it.

    Returns:
        A 3-tuple: submit button update, output box update, and the session id.
        A new id is allocated on success; the same ``uid`` is returned when
        the submission is rejected, so the user can fix the inputs and retry.
    """
    all_mappings = read_column_mapping(uid)
    invalid_updates = check_column_mapping_keys_validity(all_mappings)
    if invalid_updates is not None:
        return (*invalid_updates, uid)

    label_mapping, feature_mapping = construct_label_and_feature_mapping(all_mappings)

    leaderboard_dataset = None
    if os.environ.get(HF_SPACE_ID) == "giskardai/giskard-evaluator":
        leaderboard_dataset = LEADERBOARD

    inference_type = None
    if local:
        inference_type = "hf_pipeline"
    if inference and inference_token:
        inference_type = "hf_inference_api"
    if inference_type is None:
        # Neither local run nor Inference API (with a token) was selected.
        gr.Warning(
            "Please select a run mode: run locally, or use the Inference API "
            "with a valid Hugging Face token."
        )
        return (gr.update(interactive=True), gr.update(visible=False), uid)

    # TODO: Set column mapping for some dataset such as `amazon_polarity`
    command = [
        "giskard_scanner",
        "--loader",
        "huggingface",
        "--model",
        m_id,
        "--dataset",
        d_id,
        "--dataset_config",
        config,
        "--dataset_split",
        split,
        "--output_format",
        "markdown",
        "--output_portal",
        "huggingface",
        "--feature_mapping",
        json.dumps(feature_mapping),
        "--label_mapping",
        json.dumps(label_mapping),
        "--scan_config",
        get_yaml_path(uid),
        "--inference_type",
        inference_type,
        "--inference_api_token",
        # Never put None into the argument list.
        inference_token or "",
    ]

    # The token to publish post
    hf_write_token = os.environ.get(HF_WRITE_TOKEN)
    if hf_write_token:
        command += ["--hf_token", hf_write_token]

    # The repo to publish post
    discussion_repo = os.environ.get(HF_REPO_ID) or os.environ.get(HF_SPACE_ID)
    if discussion_repo:
        # TODO: Replace by the model id
        command += ["--discussion_repo", discussion_repo]

    # The repo to publish for ranking
    if leaderboard_dataset:
        command += ["--leaderboard_dataset", leaderboard_dataset]

    # The info to upload to Giskard hub
    for flag, env_name in _GSK_HUB_CLI_OPTIONS:
        value = os.environ.get(env_name)
        if value:
            command += [flag, value]

    eval_str = f"[{m_id}]<{d_id}({config}, {split} set)>"
    logger.info(f"Start local evaluation on {eval_str}")
    save_job_to_pipe(uid, command, eval_str, _JOB_PIPE_LOCK)
    write_log_to_user_file(
        uid,
        f"Start local evaluation on {eval_str}. Please wait for your job to start...\n",
    )
    gr.Info(f"Start local evaluation on {eval_str}")

    return (
        gr.update(interactive=False),  # Submit button
        gr.update(lines=5, visible=True, interactive=False),
        uuid.uuid4(),  # Allocate a new uuid
    )