"""Gradio UI for evaluating a text-classification model on a dataset.

Builds the input widgets (model / dataset selection, label and feature
mapping, advanced model-wrap and scanner options) and wires their events to
the helpers in ``text_classification_ui_helpers`` and ``io_utils``.
"""

import uuid

import gradio as gr

from io_utils import (
    get_logs_file,
    read_inference_type,
    read_scanners,
    write_inference_type,
    write_scanners,
)
from text_classification_ui_helpers import (
    check_dataset_and_get_config,
    check_dataset_and_get_split,
    check_model_and_show_prediction,
    try_submit,
    write_column_mapping_to_config,
)
from wordings import CONFIRM_MAPPING_DETAILS_MD, INTRODUCTION_MD

MAX_LABELS = 20
MAX_FEATURES = 20

EXAMPLE_MODEL_ID = "cardiffnlp/twitter-roberta-base-sentiment-latest"
EXAMPLE_DATA_ID = "tweet_eval"
CONFIG_PATH = "./config.yaml"


def get_demo(demo):
    """Lay out the evaluation UI and register its event handlers.

    Presumably called by the app inside its ``with gr.Blocks() as demo:``
    context (confirm against the caller); every component created here is
    attached to whatever layout context is currently open.

    Args:
        demo: The enclosing Blocks app. Only its ``load`` method is used,
            to refresh the evaluation log textbox periodically.
    """
    with gr.Row():
        gr.Markdown(INTRODUCTION_MD)

    with gr.Row():
        model_id_input = gr.Textbox(
            label="Hugging Face model id",
            placeholder=EXAMPLE_MODEL_ID + " (press enter to confirm)",
        )
        dataset_id_input = gr.Textbox(
            label="Hugging Face Dataset id",
            placeholder=EXAMPLE_DATA_ID + " (press enter to confirm)",
        )

    with gr.Row():
        dataset_config_input = gr.Dropdown(label="Dataset Config", visible=False)
        dataset_split_input = gr.Dropdown(label="Dataset Split", visible=False)

    with gr.Row():
        example_input = gr.Markdown("Example Input", visible=False)
    with gr.Row():
        example_prediction = gr.Label(label="Model Prediction Sample", visible=False)

    with gr.Row():
        with gr.Accordion(
            label="Label and Feature Mapping", visible=False, open=False
        ) as column_mapping_accordion:
            with gr.Row():
                gr.Markdown(CONFIRM_MAPPING_DETAILS_MD)
            # Fixed pool of hidden dropdowns, revealed/filled later by
            # check_model_and_show_prediction (they are among its outputs).
            # Presumably the left column holds the MAX_LABELS label mappings
            # and the right column the MAX_FEATURES feature mappings.
            column_mappings = []
            with gr.Row():
                with gr.Column():
                    for _ in range(MAX_LABELS):
                        column_mappings.append(gr.Dropdown(visible=False))
                with gr.Column():
                    for _ in range(MAX_FEATURES):
                        column_mappings.append(gr.Dropdown(visible=False))

    with gr.Accordion(label="Model Wrap Advance Config (optional)", open=False):
        run_local = gr.Checkbox(value=True, label="Run in this Space")
        use_inference = read_inference_type(CONFIG_PATH) == "hf_inference_api"
        run_inference = gr.Checkbox(value=use_inference, label="Run with Inference API")

    with gr.Accordion(label="Scanner Advance Config (optional)", open=False):
        selected = read_scanners(CONFIG_PATH)
        # data_leakage is offered as a choice but left out of the default
        # scanners: it barely raises any issues and takes too many requests
        # when using the Inference API, causing rate-limit errors.
        # Append it only if the config has not already selected it, so the
        # choice list never contains a duplicate entry.
        scan_config = list(selected)
        if "data_leakage" not in scan_config:
            scan_config.append("data_leakage")
        scanners = gr.CheckboxGroup(
            choices=scan_config, value=selected, label="Scan Settings", visible=True
        )

    with gr.Row():
        run_btn = gr.Button(
            "Get Evaluation Result",
            variant="primary",
            interactive=True,
            size="lg",
        )

    with gr.Row():
        # A callable value is evaluated by Gradio on each app load, so every
        # session gets its own evaluation ID. A uuid4() computed here would be
        # generated once when the UI is built and shared by all visitors.
        uid_label = gr.Textbox(
            label="Evaluation ID:",
            value=lambda: str(uuid.uuid4()),
            visible=False,
            interactive=False,
        )
        logs = gr.Textbox(label="Giskard Bot Evaluation Log:", visible=False)
        # Refresh `logs` with get_logs_file(<evaluation id>) every 0.5 s.
        demo.load(get_logs_file, uid_label, logs, every=0.5)

    # Any mapping dropdown change is written out via write_column_mapping_to_config.
    gr.on(
        triggers=[label.change for label in column_mappings],
        fn=write_column_mapping_to_config,
        inputs=[dataset_id_input, dataset_config_input, dataset_split_input, *column_mappings],
    )

    gr.on(
        triggers=[
            model_id_input.change,
            dataset_config_input.change,
            dataset_split_input.change,
        ],
        fn=check_model_and_show_prediction,
        inputs=[model_id_input, dataset_id_input, dataset_config_input, dataset_split_input],
        outputs=[example_input, example_prediction, column_mapping_accordion, *column_mappings],
    )

    # The placeholder tells users to press enter, which fires `submit` (the
    # textbox keeps focus, so `blur` alone would never trigger). `blur` is
    # kept for users who click or tab away instead.
    gr.on(
        triggers=[dataset_id_input.submit, dataset_id_input.blur],
        fn=check_dataset_and_get_config,
        inputs=[dataset_id_input],
        outputs=[dataset_config_input],
    )

    dataset_config_input.change(
        check_dataset_and_get_split,
        inputs=[dataset_id_input, dataset_config_input],
        outputs=[dataset_split_input],
    )

    scanners.change(write_scanners, inputs=scanners)
    run_inference.change(write_inference_type, inputs=[run_inference])

    gr.on(
        triggers=[run_btn.click],
        fn=try_submit,
        inputs=[
            model_id_input,
            dataset_id_input,
            dataset_config_input,
            dataset_split_input,
            run_local,
            uid_label,
        ],
        outputs=[run_btn, logs],
    )

    def enable_run_btn():
        """Return an update that makes the run button clickable."""
        return gr.update(interactive=True)

    # try_submit also outputs to run_btn (presumably to disable it while a job
    # runs); any change to the inputs or options makes it clickable again.
    gr.on(
        triggers=[
            model_id_input.change,
            dataset_config_input.change,
            dataset_split_input.change,
            run_inference.change,
            run_local.change,
            scanners.change,
            *[label.change for label in column_mappings],
        ],
        fn=enable_run_btn,
        inputs=None,
        outputs=[run_btn],
    )