Spaces:
Sleeping
Sleeping
import uuid | |
import gradio as gr | |
from io_utils import read_scanners, write_scanners | |
from text_classification_ui_helpers import ( | |
get_related_datasets_from_leaderboard, | |
align_columns_and_show_prediction, | |
check_dataset, | |
precheck_model_ds_enable_example_btn, | |
show_hf_token_info, | |
try_submit, | |
write_column_mapping_to_config, | |
) | |
from text_classification import ( | |
get_example_prediction, | |
HuggingFaceInferenceAPIResponse | |
) | |
from wordings import ( | |
CONFIRM_MAPPING_DETAILS_MD, | |
INTRODUCTION_MD, | |
USE_INFERENCE_API_TIP, | |
CHECK_LOG_SECTION_RAW, | |
HF_TOKEN_INVALID_STYLED | |
) | |
# Size of the fixed pool of hidden mapping dropdowns built by get_demo():
# one dropdown per label first, then one per feature.
MAX_LABELS: int = 40
MAX_FEATURES: int = 20

# Shown as an example inside the model-id textbox placeholder.
EXAMPLE_MODEL_ID: str = "cardiffnlp/twitter-roberta-base-sentiment-latest"

# Location of the YAML configuration file, relative to the working directory.
CONFIG_PATH: str = "./config.yaml"
def get_demo():
    """Build the text-classification evaluation UI and wire up its events.

    Every component is created in place and nothing is returned. Gradio only
    allows event listeners to be attached inside a ``gr.Blocks()`` context,
    so presumably the caller invokes this inside one.
    """
    with gr.Row():
        gr.Markdown(INTRODUCTION_MD)
    # Hidden per-session id. Passing the uuid.uuid4 callable (not a value)
    # makes Gradio compute a fresh id each time the page loads.
    uid_label = gr.Textbox(
        label="Evaluation ID:", value=uuid.uuid4, visible=False, interactive=False
    )

    # NOTE(review): container nesting below was reconstructed from a
    # whitespace-stripped copy; confirm it matches the intended layout.
    with gr.Row():
        model_id_input = gr.Textbox(
            label="Hugging Face model id",
            placeholder=EXAMPLE_MODEL_ID + " (press enter to confirm)",
        )
        with gr.Column():
            dataset_id_input = gr.Dropdown(
                choices=[],
                value="",
                allow_custom_value=True,
                label="Hugging Face Dataset id",
            )
            with gr.Row():
                dataset_config_input = gr.Dropdown(
                    label="Dataset Config", visible=False, allow_custom_value=True
                )
                dataset_split_input = gr.Dropdown(
                    label="Dataset Split", visible=False, allow_custom_value=True
                )

    with gr.Row():
        first_line_ds = gr.DataFrame(label="Dataset preview", visible=False)
    with gr.Row():
        loading_status = gr.HTML(visible=True)
    with gr.Row():
        example_btn = gr.Button(
            "Validate model & dataset",
            visible=True,
            variant="primary",
            interactive=False,
        )
    with gr.Row():
        example_input = gr.HTML(visible=False)
    with gr.Row():
        example_prediction = gr.Label(label="Model Prediction Sample", visible=False)

    with gr.Row():
        with gr.Accordion(
            label="Label and Feature Mapping", visible=False, open=False
        ) as column_mapping_accordion:
            with gr.Row():
                gr.Markdown(CONFIRM_MAPPING_DETAILS_MD)
            # Fixed pool of hidden dropdowns. The first MAX_LABELS map labels and
            # the remaining MAX_FEATURES map features. They are all outputs of
            # the validation step wired to example_btn below.
            column_mappings = []
            with gr.Row():
                with gr.Column():
                    gr.Markdown("# Label Mapping")
                    for _ in range(MAX_LABELS):
                        column_mappings.append(gr.Dropdown(visible=False))
                with gr.Column():
                    gr.Markdown("# Feature Mapping")
                    for _ in range(MAX_FEATURES):
                        column_mappings.append(gr.Dropdown(visible=False))

    with gr.Accordion(label="Model Wrap Advance Config", open=True):
        gr.HTML(USE_INFERENCE_API_TIP)
        run_inference = gr.Checkbox(value=True, label="Run with Inference API")
        inference_token = gr.Textbox(
            placeholder="hf-xxxxxxxxxxxxxxxxxxxx",
            value="",
            label="HF Token for Inference API",
            visible=True,
            interactive=True,
        )
        inference_token_info = gr.HTML(value=HF_TOKEN_INVALID_STYLED, visible=False)
        inference_token.change(
            fn=show_hf_token_info,
            inputs=[inference_token],
            outputs=[inference_token_info],
        )

    with gr.Accordion(label="Scanner Advance Config (optional)", open=False):
        scanners = gr.CheckboxGroup(label="Scan Settings", visible=True)

        def get_scanners(uid):
            """Return a gr.update that fills `scanners` from the stored scanners for `uid`."""
            selected = read_scanners(uid)
            # we remove data_leakage from the default scanners
            # Reason: data_leakage barely raises any issues and takes too many requests
            # when using inference API, causing rate limit error
            # It stays available as an opt-in choice. Avoid listing it twice
            # when it is already part of the stored selection.
            scan_config = list(selected)
            if "data_leakage" not in scan_config:
                scan_config.append("data_leakage")
            return gr.update(
                choices=scan_config, value=selected, label="Scan Settings", visible=True
            )

    with gr.Row():
        run_btn = gr.Button(
            "Get Evaluation Result",
            variant="primary",
            interactive=False,
            size="lg",
        )

    with gr.Row():
        # NOTE(review): `every` only re-polls when `value` is callable. With a
        # constant string, as here, it has no effect.
        logs = gr.Textbox(
            value=CHECK_LOG_SECTION_RAW,
            label="Giskard Bot Evaluation Guide:",
            visible=False,
            every=0.5,
        )

    scanners.change(write_scanners, inputs=[scanners, uid_label])
    # get_scanners used to be defined without being registered, so the
    # checkbox group never received any choices. uid_label gets a new id from
    # try_submit (see below) and, presumably, when its load-time value is set.
    # Refresh the scanner list whenever it changes.
    uid_label.change(fn=get_scanners, inputs=[uid_label], outputs=[scanners])

    gr.on(
        triggers=[model_id_input.change],
        fn=get_related_datasets_from_leaderboard,
        inputs=[model_id_input],
        outputs=[dataset_id_input],
    ).then(
        fn=check_dataset,
        inputs=[dataset_id_input],
        outputs=[dataset_config_input, dataset_split_input, loading_status],
    )

    gr.on(
        triggers=[dataset_id_input.change],
        fn=check_dataset,
        inputs=[dataset_id_input],
        outputs=[dataset_config_input, dataset_split_input, loading_status],
    )

    # Persist the mapping on both change and input:
    # label.change sometimes does not pass the changed value.
    gr.on(
        triggers=[
            *(mapping.change for mapping in column_mappings),
            *(mapping.input for mapping in column_mappings),
        ],
        fn=write_column_mapping_to_config,
        inputs=[
            uid_label,
            *column_mappings,
        ],
    )

    gr.on(
        triggers=[
            model_id_input.change,
            dataset_id_input.change,
            dataset_config_input.change,
            dataset_split_input.change,
        ],
        fn=precheck_model_ds_enable_example_btn,
        inputs=[
            model_id_input,
            dataset_id_input,
            dataset_config_input,
            dataset_split_input,
        ],
        outputs=[example_btn, first_line_ds, loading_status],
    )

    gr.on(
        triggers=[
            example_btn.click,
        ],
        fn=align_columns_and_show_prediction,
        inputs=[
            model_id_input,
            dataset_id_input,
            dataset_config_input,
            dataset_split_input,
            uid_label,
            run_inference,
            inference_token,
        ],
        outputs=[
            example_input,
            example_prediction,
            column_mapping_accordion,
            run_btn,
            loading_status,
            *column_mappings,
        ],
    )

    gr.on(
        triggers=[
            run_btn.click,
        ],
        fn=try_submit,
        inputs=[
            model_id_input,
            dataset_id_input,
            dataset_config_input,
            dataset_split_input,
            run_inference,
            inference_token,
            uid_label,
        ],
        outputs=[run_btn, logs, uid_label],
    )

    def enable_run_btn(
        run_inference,
        inference_token,
        model_id,
        dataset_id,
        dataset_config,
        dataset_split,
        *mapping_values,
    ):
        """Return a gr.update that sets whether `run_btn` is clickable.

        The button is enabled only when the Inference API checkbox is on, a
        token is entered, model/dataset/config/split are all set, and at
        least one label/feature mapping dropdown holds a value.
        """
        if not run_inference or not inference_token:
            return gr.update(interactive=False)
        # Truthiness also rejects None, which an unset Dropdown yields.
        if not all((model_id, dataset_id, dataset_config, dataset_split)):
            return gr.update(interactive=False)
        # Component attributes such as `column_mapping_accordion.visible` keep
        # their construction-time value. Per-session updates do not change them,
        # so the previous visibility checks always disabled the button (and the
        # `inference_token_info.visible` check never fired). A filled mapping
        # dropdown is used instead as evidence that validation has run.
        # NOTE(review): assumes align_columns_and_show_prediction fills the
        # mapping dropdowns only after a successful validation. Confirm.
        if not any(mapping_values):
            return gr.update(interactive=False)
        return gr.update(interactive=True)

    gr.on(
        triggers=[
            run_inference.input,
            inference_token.input,
            scanners.input,
            *(mapping.input for mapping in column_mappings),
        ],
        fn=enable_run_btn,
        inputs=[
            run_inference,
            inference_token,
            model_id_input,
            dataset_id_input,
            dataset_config_input,
            dataset_split_input,
            *column_mappings,
        ],
        outputs=[run_btn],
    )