giskard-evaluator / app_text_classification.py
import gradio as gr
import uuid
from io_utils import (
    read_scanners,
    write_scanners,
    read_inference_type,
    write_inference_type,
    get_logs_file,
)
from wordings import INTRODUCTION_MD, CONFIRM_MAPPING_DETAILS_MD
from text_classification_ui_helpers import (
    try_submit,
    check_dataset_and_get_config,
    check_dataset_and_get_split,
    check_model_and_show_prediction,
    write_column_mapping_to_config,
)

MAX_LABELS = 20
MAX_FEATURES = 20

EXAMPLE_MODEL_ID = "cardiffnlp/twitter-roberta-base-sentiment-latest"
EXAMPLE_DATA_ID = "tweet_eval"
CONFIG_PATH = "./config.yaml"


def get_demo(demo):
    with gr.Row():
        gr.Markdown(INTRODUCTION_MD)
    with gr.Row():
        model_id_input = gr.Textbox(
            label="Hugging Face model id",
            placeholder=EXAMPLE_MODEL_ID + " (press enter to confirm)",
        )
        dataset_id_input = gr.Textbox(
            label="Hugging Face Dataset id",
            placeholder=EXAMPLE_DATA_ID + " (press enter to confirm)",
        )
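    # The config/split dropdowns and the example widgets below start hidden;
    # the callbacks wired up further down populate and reveal them once a
    # dataset and model have been checked.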
    with gr.Row():
        dataset_config_input = gr.Dropdown(label="Dataset Config", visible=False)
        dataset_split_input = gr.Dropdown(label="Dataset Split", visible=False)
    with gr.Row():
        example_input = gr.Markdown("Example Input", visible=False)
    with gr.Row():
        example_prediction = gr.Label(label="Model Prediction Sample", visible=False)
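    # Column mapping accordion: a fixed pool of MAX_LABELS label dropdowns and
    # MAX_FEATURES feature dropdowns, created hidden up front so the UI helpers
    # only have to toggle the ones a given dataset actually needs.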
    with gr.Row():
        with gr.Accordion(
            label="Label and Feature Mapping", visible=False, open=False
        ) as column_mapping_accordion:
            with gr.Row():
                gr.Markdown(CONFIRM_MAPPING_DETAILS_MD)
            column_mappings = []
            with gr.Row():
                with gr.Column():
                    for _ in range(MAX_LABELS):
                        column_mappings.append(gr.Dropdown(visible=False))
                with gr.Column():
                    for _ in range(MAX_LABELS, MAX_LABELS + MAX_FEATURES):
                        column_mappings.append(gr.Dropdown(visible=False))
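    # Advanced options: run the model locally in this Space or through the
    # Hugging Face Inference API; the choice is read from, and written back to,
    # the config file.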
with gr.Accordion(label="Model Wrap Advance Config (optional)", open=False):
run_local = gr.Checkbox(value=True, label="Run in this Space")
use_inference = read_inference_type(CONFIG_PATH) == "hf_inference_api"
run_inference = gr.Checkbox(value=use_inference, label="Run with Inference API")
with gr.Accordion(label="Scanner Advance Config (optional)", open=False):
selected = read_scanners(CONFIG_PATH)
# currently we remove data_leakage from the default scanners
# Reason: data_leakage barely raises any issues and takes too many requests
# when using inference API, causing rate limit error
scan_config = selected + ["data_leakage"]
scanners = gr.CheckboxGroup(
choices=scan_config, value=selected, label="Scan Settings", visible=True
)
    with gr.Row():
        run_btn = gr.Button(
            "Get Evaluation Result",
            variant="primary",
            interactive=True,
            size="lg",
        )
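    # Each evaluation gets a UUID, carried in a hidden textbox so callbacks can
    # tag their run; the log box polls get_logs_file every 0.5 s after load.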
    with gr.Row():
        uid = uuid.uuid4()
        uid_label = gr.Textbox(
            label="Evaluation ID:", value=uid, visible=False, interactive=False
        )
        logs = gr.Textbox(label="Giskard Bot Evaluation Log:", visible=False)
        demo.load(get_logs_file, uid_label, logs, every=0.5)
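    # Persist the column-mapping selections whenever any mapping dropdown changes.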
    gr.on(
        triggers=[label.change for label in column_mappings],
        fn=write_column_mapping_to_config,
        inputs=[
            dataset_id_input,
            dataset_config_input,
            dataset_split_input,
            *column_mappings,
        ],
    )
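    # Refresh the example input/prediction preview and the mapping dropdowns
    # whenever the model, dataset config, or split changes.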
    gr.on(
        triggers=[
            model_id_input.change,
            dataset_config_input.change,
            dataset_split_input.change,
        ],
        fn=check_model_and_show_prediction,
        inputs=[
            model_id_input,
            dataset_id_input,
            dataset_config_input,
            dataset_split_input,
        ],
        outputs=[
            example_input,
            example_prediction,
            column_mapping_accordion,
            *column_mappings,
        ],
    )
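    # Resolve the dataset: id -> available configs -> available splits.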
    dataset_id_input.blur(
        check_dataset_and_get_config, dataset_id_input, dataset_config_input
    )
    dataset_config_input.change(
        check_dataset_and_get_split,
        inputs=[dataset_id_input, dataset_config_input],
        outputs=[dataset_split_input],
    )
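    # Persist scanner selection and inference mode whenever they change.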
    scanners.change(write_scanners, inputs=scanners)
    run_inference.change(write_inference_type, inputs=[run_inference])
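    # Launch the evaluation with the current selections; try_submit updates the
    # run button and the log box.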
    gr.on(
        triggers=[
            run_btn.click,
        ],
        fn=try_submit,
        inputs=[
            model_id_input,
            dataset_id_input,
            dataset_config_input,
            dataset_split_input,
            run_local,
            uid_label,
        ],
        outputs=[run_btn, logs],
    )
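    # Re-enable the run button whenever any evaluation-relevant input changes.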
    def enable_run_btn():
        return gr.update(interactive=True)

    gr.on(
        triggers=[
            model_id_input.change,
            dataset_config_input.change,
            dataset_split_input.change,
            run_inference.change,
            run_local.change,
            scanners.change,
        ],
        fn=enable_run_btn,
        inputs=None,
        outputs=[run_btn],
    )
    gr.on(
        triggers=[label.change for label in column_mappings],
        fn=enable_run_btn,
        inputs=None,
        outputs=[run_btn],
    )
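

# A minimal, optional local-run sketch. Assumption: the Space's real entry
# point lives in another file (e.g. an app.py) that builds the Blocks context
# and calls get_demo(); the queue()/launch() calls below are illustrative
# defaults, not taken from this repository.
if __name__ == "__main__":
    with gr.Blocks() as demo:
        get_demo(demo)
    demo.queue()  # the queue is needed for the `every=0.5` log polling above
    demo.launch()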