giskard-evaluator / app_leaderboard.py
inoki-giskard's picture
GSK-2498-suggest-a-dataset-for-model (#46)
8c47a22 verified
raw
history blame
4.88 kB
import logging
import datasets
import gradio as gr
from fetch_utils import (check_dataset_and_get_config,
check_dataset_and_get_split)
from text_classification_ui_helpers import LEADERBOARD
import leaderboard
def get_records_from_dataset_repo(dataset_id):
    """Load a dataset repository's records as a pandas DataFrame.

    The first config reported by ``check_dataset_and_get_config`` and the
    first split reported by ``check_dataset_and_get_split`` are used.

    Args:
        dataset_id: Identifier of the dataset repository to load.

    Returns:
        The selected split converted with ``to_pandas()``, or ``None`` when
        no config/split is available or loading fails. Failures are logged
        as warnings rather than raised.
    """
    dataset_config = check_dataset_and_get_config(dataset_id)
    logging.info(f"Dataset {dataset_id} has configs {dataset_config}")
    # Guard before indexing: an empty/None result would otherwise raise
    # outside the try block instead of following the "return None" contract.
    if not dataset_config:
        logging.warning(f"Dataset {dataset_id} has no usable config")
        return None

    dataset_split = check_dataset_and_get_split(dataset_id, dataset_config[0])
    logging.info(f"Dataset {dataset_id} has splits {dataset_split}")
    if not dataset_split:
        logging.warning(
            f"Dataset {dataset_id} has no usable split for config {dataset_config[0]}"
        )
        return None

    try:
        ds = datasets.load_dataset(dataset_id, dataset_config[0])[dataset_split[0]]
        return ds.to_pandas()
    except Exception as e:
        # Best-effort boundary: any loading problem degrades to "no records".
        logging.warning(
            f"Failed to load dataset {dataset_id} with config {dataset_config}: {e}"
        )
        return None
def get_model_ids(ds):
    """Return the dropdown choices for the model filter.

    Args:
        ds: Table-like object whose ``"model_id"`` column supports
            ``.tolist()`` (a pandas DataFrame in this file).

    Returns:
        A list starting with the wildcard ``"Any"`` followed by the distinct
        model ids in order of first appearance.
    """
    logging.info(f"Dataset {ds} column names: {ds['model_id']}")
    models = ds["model_id"].tolist()
    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # choices are stable across runs (set iteration order is not).
    return ["Any", *dict.fromkeys(models)]
def get_dataset_ids(ds):
    """Return the dropdown choices for the dataset filter.

    Args:
        ds: Table-like object whose ``"dataset_id"`` column supports
            ``.tolist()`` (a pandas DataFrame in this file).

    Returns:
        A list starting with the wildcard ``"Any"`` followed by the distinct
        dataset ids in order of first appearance.
    """
    logging.info(f"Dataset {ds} column names: {ds['dataset_id']}")
    # Named ``seen_ids`` rather than ``datasets`` so the imported ``datasets``
    # module is not shadowed inside this function.
    seen_ids = ds["dataset_id"].tolist()
    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # choices are stable across runs (set iteration order is not).
    return ["Any", *dict.fromkeys(seen_ids)]
def get_types(ds):
    """Translate the column dtypes of *ds* into table datatype names.

    Classification uses each dtype's ``kind`` code instead of substring
    replacement on the dtype name, which mangled names such as ``uint64``
    (-> ``"unumber"``) and left ``float32``/``int32``/``Int64``/datetime
    names untranslated.

    Args:
        ds: A pandas DataFrame.

    Returns:
        One name per column, in column order: ``"bool"`` for booleans,
        ``"number"`` for signed/unsigned integers and floats, ``"date"`` for
        datetimes, and ``"markdown"`` for everything else (including object
        columns, which hold the HTML links built for display).
    """
    kind_to_type = {
        "b": "bool",
        "i": "number",
        "u": "number",
        "f": "number",
        "M": "date",
    }
    # Fall back to "O" (object) for any dtype that does not expose ``kind``.
    return [
        kind_to_type.get(getattr(dtype, "kind", "O"), "markdown")
        for dtype in ds.dtypes.to_list()
    ]
def get_display_df(df):
    """Return a copy of *df* with its link-like columns rendered as HTML anchors.

    ``model_id`` and ``dataset_id`` values are linked to their pages under
    ``https://huggingface.co/``; ``report_link`` values are linked to
    themselves. Columns that are absent are skipped and all other columns are
    copied unchanged. The input frame is not modified.

    Cell values are HTML-escaped before interpolation so that a value
    containing ``<``, ``>``, ``&`` or quotes cannot break out of the anchor
    markup.

    Args:
        df: A pandas DataFrame.

    Returns:
        A new DataFrame with the same columns as *df*.
    """
    import html  # function-scope import keeps this block self-contained

    # U+1F517 LINK SYMBOL, written as an escape so it survives re-encoding
    # (it had been corrupted into mojibake in the literal).
    link_icon = "\U0001F517"

    def _anchor(url, label):
        # Build one styled anchor; both parts are escaped for HTML context.
        safe_url = html.escape(str(url), quote=True)
        safe_label = html.escape(str(label), quote=True)
        return (
            f'<a href="{safe_url}" target="_blank" style="color:blue">'
            f"{link_icon}{safe_label}</a>"
        )

    # Column name -> URL prefix prepended to the cell value to form the href.
    url_prefixes = {
        "model_id": "https://huggingface.co/",
        "dataset_id": "https://huggingface.co/datasets/",
        "report_link": "",
    }

    display_df = df.copy()
    for column, prefix in url_prefixes.items():
        if column in display_df.columns:
            display_df[column] = display_df[column].apply(
                # Bind ``prefix`` as a default to avoid late-binding closures.
                lambda value, prefix=prefix: _anchor(f"{prefix}{value}", value)
            )
    return display_df
def get_demo():
    """Build the leaderboard UI: filter controls plus the results table.

    Loads the records of the ``LEADERBOARD`` dataset into
    ``leaderboard.records``, creates the task/model/dataset dropdowns, the
    column selector and the table, and registers ``filter_table`` so that any
    change to a control re-renders the table.

    Gradio components are created as a side effect, so this is expected to be
    called inside an active Gradio layout context (presumably ``gr.Blocks``;
    the caller is not visible here). Returns ``None``.
    """
    leaderboard.records = get_records_from_dataset_repo(LEADERBOARD)
    records = leaderboard.records

    # The loader returns None on failure; without this guard the first column
    # access below raises "'NoneType' object is not subscriptable".
    if records is None:
        logging.warning(
            f"Leaderboard records for {LEADERBOARD} could not be loaded"
        )
        gr.Markdown("Leaderboard records could not be loaded. Please try again later.")
        return

    model_ids = get_model_ids(records)
    dataset_ids = get_dataset_ids(records)

    column_names = records.columns.tolist()
    default_columns = ["model_id", "dataset_id", "total_issues", "report_link"]
    # NOTE(review): the initial table is not filtered by the default task,
    # unlike filter_table below; confirm whether that is intended.
    default_df = records[default_columns]  # extract columns selected
    types = get_types(default_df)
    display_df = get_display_df(default_df)  # the styled dataframe to display

    with gr.Row():
        task_select = gr.Dropdown(
            label="Task",
            choices=["text_classification", "tabular"],
            value="text_classification",
            interactive=True,
        )
        model_select = gr.Dropdown(
            label="Model id", choices=model_ids, value=model_ids[0], interactive=True
        )
        dataset_select = gr.Dropdown(
            label="Dataset id",
            choices=dataset_ids,
            value=dataset_ids[0],
            interactive=True,
        )

    with gr.Row():
        columns_select = gr.CheckboxGroup(
            label="Show columns",
            choices=column_names,
            value=default_columns,
            interactive=True,
        )

    with gr.Row():
        leaderboard_df = gr.DataFrame(display_df, datatype=types, interactive=False)

    @gr.on(
        triggers=[
            model_select.change,
            dataset_select.change,
            columns_select.change,
            task_select.change,
        ],
        inputs=[model_select, dataset_select, columns_select, task_select],
        outputs=[leaderboard_df],
    )
    def filter_table(model_id, dataset_id, columns, task):
        """Recompute the table for the current control values.

        Reads ``leaderboard.records`` at call time (not the snapshot taken
        when the UI was built), keeps rows matching ``task`` and, unless
        set to ``"Any"`` or empty, ``model_id`` / ``dataset_id``, then
        restricts the result to ``columns``.
        """
        records = leaderboard.records
        # filter the table based on task
        df = records[(records["task"] == task)]
        # filter the table based on the model_id and dataset_id
        if model_id and model_id != "Any":
            df = df[(df["model_id"] == model_id)]
        if dataset_id and dataset_id != "Any":
            df = df[(df["dataset_id"] == dataset_id)]

        # filter the table based on the columns
        df = df[columns]
        types = get_types(df)
        display_df = get_display_df(df)
        return gr.update(value=display_df, datatype=types, interactive=False)