giskard-evaluator / app_text_classification.py
ZeroCommand's picture
merge and resolve conflicts
96a1184
raw
history blame
8.18 kB
import uuid
import gradio as gr
from io_utils import read_scanners, write_scanners
from text_classification_ui_helpers import (
get_related_datasets_from_leaderboard,
align_columns_and_show_prediction,
get_dataset_splits,
check_dataset,
precheck_model_ds_enable_example_btn,
try_submit,
empty_column_mapping,
write_column_mapping_to_config,
enable_run_btn,
)
import logging
from wordings import (
CONFIRM_MAPPING_DETAILS_MD,
INTRODUCTION_MD,
LOG_IN_TIPS,
CHECK_LOG_SECTION_RAW,
)
# Sizes of the pre-allocated dropdown pools in the column-mapping accordion:
# MAX_LABELS label-mapping dropdowns followed by MAX_FEATURES feature-mapping ones.
MAX_LABELS, MAX_FEATURES = 40, 20
# Model id shown as the placeholder hint in the model-id textbox.
EXAMPLE_MODEL_ID = "cardiffnlp/twitter-roberta-base-sentiment-latest"
# NOTE(review): not referenced in this module section; presumably the default
# config file path used elsewhere — confirm before relying on it.
CONFIG_PATH = "./config.yaml"
logger = logging.getLogger(name=__name__)
def get_demo():
    """Build the text-classification evaluation UI and wire its event handlers.

    Every component is created inside whatever ``gr.Blocks`` context is active
    when this function is called (it opens no Blocks of its own), so the caller
    must provide one. Returns ``None``; the effects are the components and the
    event listeners registered on that enclosing Blocks.
    """
    with gr.Row():
        gr.Markdown(INTRODUCTION_MD)

    # Hidden per-session evaluation id. ``value`` is a callable, so Gradio calls
    # uuid.uuid4 when the app loads. The id is passed to the scanner and
    # column-mapping helpers below.
    uid_label = gr.Textbox(
        label="Evaluation ID:", value=uuid.uuid4, visible=False, interactive=False
    )

    with gr.Accordion(label="Login to Use This Space", open=True):
        gr.HTML(LOG_IN_TIPS)
        gr.LoginButton()

    # ----------------------------------------------------------- model / dataset
    with gr.Row():
        model_id_input = gr.Textbox(
            label="Hugging Face Model id",
            placeholder=EXAMPLE_MODEL_ID + " (press enter to confirm)",
        )
        with gr.Column():
            # Choices start empty; the model_id_input.change handler below
            # updates them via get_related_datasets_from_leaderboard.
            dataset_id_input = gr.Dropdown(
                choices=[],
                value="",
                allow_custom_value=True,
                label="Hugging Face Dataset id",
            )
            with gr.Row():
                # Start hidden; check_dataset / get_dataset_splits update them.
                dataset_config_input = gr.Dropdown(
                    label="Dataset Config", visible=False, allow_custom_value=True
                )
                dataset_split_input = gr.Dropdown(
                    label="Dataset Split", visible=False, allow_custom_value=True
                )

    with gr.Row():
        first_line_ds = gr.DataFrame(label="Dataset Preview", visible=False)
    with gr.Row():
        loading_dataset_info = gr.HTML(visible=True)
    with gr.Row():
        # Disabled initially; precheck_model_ds_enable_example_btn updates it.
        example_btn = gr.Button(
            "Validate Model & Dataset",
            visible=True,
            variant="primary",
            interactive=False,
        )
    with gr.Row():
        loading_validation = gr.HTML(visible=True)
    with gr.Row():
        validation_result = gr.HTML(visible=False)
    with gr.Row():
        example_input = gr.Textbox(label="Example Input", visible=False, interactive=False)
        example_prediction = gr.Label(label="Model Sample Prediction", visible=False)

    # ---------------------------------------------------- label/feature mapping
    with gr.Row():
        with gr.Accordion(
            label="Label and Feature Mapping", visible=False, open=False
        ) as column_mapping_accordion:
            with gr.Row():
                gr.Markdown(CONFIRM_MAPPING_DETAILS_MD)
            # Fixed pool of hidden dropdowns, presumably pre-allocated because
            # event outputs must be declared when events are wired. The first
            # MAX_LABELS entries are label mappings; the next MAX_FEATURES are
            # feature mappings. align_columns_and_show_prediction updates them.
            column_mappings = []
            with gr.Row():
                with gr.Column():
                    gr.Markdown("# Label Mapping")
                    for _ in range(MAX_LABELS):
                        column_mappings.append(gr.Dropdown(visible=False))
                with gr.Column():
                    gr.Markdown("# Feature Mapping")
                    for _ in range(MAX_FEATURES):
                        column_mappings.append(gr.Dropdown(visible=False))

    # ------------------------------------------------------------------ scanners
    with gr.Accordion(label="Scanner Advanced Config (optional)", open=False):
        scanners = gr.CheckboxGroup(visible=True)

        @gr.on(triggers=[uid_label.change], inputs=[uid_label], outputs=[scanners])
        def get_scanners(uid):
            """Populate the scanner checkboxes for evaluation ``uid``.

            Every scanner in ``scan_config`` is offered as a choice. The
            pre-selected value is whatever ``read_scanners(uid)`` returns.
            """
            selected = read_scanners(uid)
            # This list holds the *choices*, so data_leakage remains selectable.
            # The default selection comes from read_scanners(uid), not from this
            # list. According to the original author, data_leakage is meant to be
            # left out of those defaults: it rarely raises issues and sends many
            # inference-API requests, which causes rate-limit errors.
            scan_config = [
                "ethical_bias",
                "text_perturbation",
                "robustness",
                "performance",
                "underconfidence",
                "overconfidence",
                "spurious_correlation",
                "data_leakage",
            ]
            return gr.update(
                choices=scan_config, value=selected, label="Scan Settings", visible=True
            )

    with gr.Row():
        run_btn = gr.Button(
            "Get Evaluation Result",
            variant="primary",
            interactive=False,
            size="lg",
        )

    with gr.Row():
        # NOTE(review): Gradio documents `every` as taking effect only when
        # `value` is a callable. Here `value` is a constant string, so polling
        # likely never happens and `logs` changes only via try_submit's
        # outputs. Confirm against the installed Gradio version.
        logs = gr.Textbox(
            value=CHECK_LOG_SECTION_RAW,
            label="Log",
            visible=False,
            every=0.5,
        )

    # ------------------------------------------------------------ event wiring
    scanners.change(write_scanners, inputs=[scanners, uid_label])

    # A new model id refreshes the dataset suggestions, then re-checks the
    # (possibly updated) dataset.
    gr.on(
        triggers=[model_id_input.change],
        fn=get_related_datasets_from_leaderboard,
        inputs=[model_id_input, dataset_id_input],
        outputs=[dataset_id_input],
    ).then(
        fn=check_dataset,
        inputs=[dataset_id_input],
        outputs=[dataset_config_input, dataset_split_input, loading_dataset_info],
    )

    gr.on(
        triggers=[dataset_id_input.input, dataset_id_input.select],
        fn=check_dataset,
        inputs=[dataset_id_input],
        outputs=[dataset_config_input, dataset_split_input, loading_dataset_info],
    )

    dataset_config_input.change(
        fn=get_dataset_splits,
        inputs=[dataset_id_input, dataset_config_input],
        outputs=[dataset_split_input],
    )

    # empty_column_mapping (defined elsewhere) receives only the uid and has no
    # outputs. Presumably it resets the stored mapping whenever the
    # model/dataset/config changes; verify in text_classification_ui_helpers.
    gr.on(
        triggers=[model_id_input.change, dataset_id_input.change, dataset_config_input.change],
        fn=empty_column_mapping,
        inputs=[uid_label],
    )

    # Persist the mapping whenever any mapping dropdown changes.
    gr.on(
        triggers=[label.change for label in column_mappings],
        fn=write_column_mapping_to_config,
        inputs=[
            uid_label,
            *column_mappings,
        ],
    )
    # label.change sometimes does not pass the changed value, so the same
    # handler is registered on .input as well.
    gr.on(
        triggers=[label.input for label in column_mappings],
        fn=write_column_mapping_to_config,
        inputs=[
            uid_label,
            *column_mappings,
        ],
    )

    # NOTE(review): Gradio's Textbox.change also fires on user input, so
    # listing both model_id_input.change and model_id_input.input may run the
    # precheck twice per edit. This may be a deliberate workaround like the
    # label.input one above; confirm before removing either trigger.
    gr.on(
        triggers=[
            model_id_input.change,
            model_id_input.input,
            dataset_id_input.change,
            dataset_config_input.change,
            dataset_split_input.change,
        ],
        fn=precheck_model_ds_enable_example_btn,
        inputs=[
            model_id_input,
            dataset_id_input,
            dataset_config_input,
            dataset_split_input,
        ],
        outputs=[
            example_btn,
            first_line_ds,
            validation_result,
            example_input,
            example_prediction,
            column_mapping_accordion,
        ],
    )

    gr.on(
        triggers=[
            example_btn.click,
        ],
        fn=align_columns_and_show_prediction,
        inputs=[
            model_id_input,
            dataset_id_input,
            dataset_config_input,
            dataset_split_input,
            uid_label,
        ],
        outputs=[
            validation_result,
            example_input,
            example_prediction,
            column_mapping_accordion,
            run_btn,
            loading_validation,
            *column_mappings,
        ],
    )

    # Submitting also outputs uid_label; a new uid re-triggers get_scanners.
    gr.on(
        triggers=[
            run_btn.click,
        ],
        fn=try_submit,
        inputs=[
            model_id_input,
            dataset_id_input,
            dataset_config_input,
            dataset_split_input,
            uid_label,
        ],
        outputs=[
            run_btn,
            logs,
            uid_label,
            validation_result,
            example_input,
            example_prediction,
            column_mapping_accordion,
        ],
    )

    gr.on(
        triggers=[
            scanners.input,
        ],
        fn=enable_run_btn,
        inputs=[
            uid_label,
            model_id_input,
            dataset_id_input,
            dataset_config_input,
            dataset_split_input,
        ],
        outputs=[run_btn],
    )

    # FIXME: the mapping dropdown values are not passed to enable_run_btn.
    # Presumably it reads the mapping persisted by write_column_mapping_to_config
    # for this uid; verify, and note there is no ordering guarantee between the
    # two handlers.
    gr.on(
        triggers=[label.input for label in column_mappings],
        fn=enable_run_btn,
        inputs=[
            uid_label,
            model_id_input,
            dataset_id_input,
            dataset_config_input,
            dataset_split_input,
        ],
        outputs=[run_btn],
    )