Spaces:
Sleeping
Sleeping
File size: 4,202 Bytes
be473e6 8f809e2 be473e6 8f809e2 be473e6 8f809e2 be473e6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
import gradio as gr
import datasets
import logging
from fetch_utils import check_dataset_and_get_config, check_dataset_and_get_split
def get_records_from_dataset_repo(dataset_id):
    """Load one split of a Hugging Face dataset repo as a pandas DataFrame.

    The first entry returned by ``check_dataset_and_get_config`` is used as the
    config, and the first entry returned by ``check_dataset_and_get_split`` for
    that config is used as the split.

    Args:
        dataset_id: Repo id passed to ``datasets.load_dataset``.

    Returns:
        The split converted with ``to_pandas()``, or ``None`` if loading or
        conversion raised (a warning is logged in that case).
    """
    configs = check_dataset_and_get_config(dataset_id)
    logging.info(f"Dataset {dataset_id} has configs {configs}")
    first_config = configs[0]

    splits = check_dataset_and_get_split(dataset_id, first_config)
    logging.info(f"Dataset {dataset_id} has splits {splits}")

    try:
        loaded = datasets.load_dataset(dataset_id, first_config)
        return loaded[splits[0]].to_pandas()
    except Exception as e:
        # Best-effort: callers receive None rather than an exception.
        logging.warning(f"Failed to load dataset {dataset_id} with config {configs}: {e}")
        return None
def get_model_ids(ds):
    """Return the distinct values of the ``model_id`` column of *ds*, sorted.

    The values are used as dropdown choices, so they are sorted (by their
    string form, which never raises on mixed types) to give a stable order
    instead of the arbitrary order of a set.

    Args:
        ds: DataFrame with a ``model_id`` column.

    Returns:
        A list of unique ``model_id`` values.
    """
    model_ids = sorted(set(ds['model_id'].tolist()), key=str)
    # Log a summary only: interpolating the whole DataFrame/column is noisy
    # and expensive.
    logging.info("Found %d unique model ids", len(model_ids))
    return model_ids
def get_dataset_ids(ds):
    """Return the distinct values of the ``dataset_id`` column of *ds*, sorted.

    The values are used as dropdown choices, so they are sorted (by their
    string form, which never raises on mixed types) to give a stable order
    instead of the arbitrary order of a set.

    Args:
        ds: DataFrame with a ``dataset_id`` column.

    Returns:
        A list of unique ``dataset_id`` values.
    """
    # Named dataset_ids (not ``datasets``) so the imported module is not shadowed.
    dataset_ids = sorted(set(ds['dataset_id'].tolist()), key=str)
    logging.info("Found %d unique dataset ids", len(dataset_ids))
    return dataset_ids
def get_types(ds):
    """Map each column dtype of *ds* to a Gradio Dataframe ``datatype`` string.

    Mapping is done on the dtype's ``kind`` code so every width and nullable
    variant is covered, not just ``object``/``float64``/``int64``:

    * object / string / categorical -> ``'markdown'`` (cells may hold HTML)
    * signed/unsigned integer, float -> ``'number'``
    * boolean -> ``'bool'``
    * datetime -> ``'date'``

    Any other dtype falls back to ``str(dtype)``, as before.

    Args:
        ds: DataFrame whose columns are typed.

    Returns:
        A list with one datatype string per column, in column order.
    """
    kind_to_datatype = {
        'O': 'markdown',
        'U': 'markdown',
        'S': 'markdown',
        'i': 'number',
        'u': 'number',
        'f': 'number',
        'b': 'bool',
        'M': 'date',
    }
    return [kind_to_datatype.get(dtype.kind, str(dtype)) for dtype in ds.dtypes]
def get_display_df(df):
    """Return a copy of *df* with the id/link columns rendered as HTML links.

    Whichever of ``model_id``, ``dataset_id`` and ``report_link`` are present
    are replaced by ``<a href=...>`` anchors. They point at the Hugging Face
    model page, the Hugging Face dataset page, and the report URL itself,
    respectively. Other columns are unchanged and *df* is not modified.

    Args:
        df: DataFrame to style.

    Returns:
        A styled copy of *df*.
    """
    import html

    # Column name -> function building the link target from a cell value.
    url_builders = {
        'model_id': lambda x: f'https://huggingface.co/{x}',
        'dataset_id': lambda x: f'https://huggingface.co/datasets/{x}',
        'report_link': lambda x: f'{x}',
    }

    def to_anchor(value, build_url):
        # Use <a>, not <p>: ``href`` on a paragraph is not a hyperlink.
        # Escape URL and text since cell values may come from an external
        # dataset repo and must not inject markup into the page.
        # NOTE(review): "π" looks like a mis-encoded emoji (a UTF-8 lead byte
        # 0xF0 decoded as cp1253) — restore the intended character.
        href = html.escape(build_url(value))
        text = html.escape(str(value))
        return f'<a href="{href}" style="color:blue">π{text}</a>'

    display_df = df.copy()
    for column, build_url in url_builders.items():
        if column in display_df.columns:
            display_df[column] = display_df[column].apply(to_anchor, args=(build_url,))
    return display_df
def _filter_records(records, task, model_id, dataset_id, columns):
    """Return the rows of *records* matching the selections, limited to *columns*.

    Rows are kept when their ``task`` equals *task*. A truthy *model_id* or
    *dataset_id* narrows the result further; an empty/None selection means
    "no filter". The filters are cumulative.
    """
    selected = records[records['task'] == task]
    if model_id:
        selected = selected[selected['model_id'] == model_id]
    if dataset_id:
        selected = selected[selected['dataset_id'] == dataset_id]
    return selected[columns]


def get_demo():
    """Build the leaderboard widgets and wire up their filtering callback.

    Loads the report records from the ``ZeroCommand/test-giskard-report``
    dataset repo. It then creates the task/model/dataset dropdowns, a column
    picker and the leaderboard table; any selector change re-renders the
    table. It creates Gradio layout components, so it is presumably called
    inside a ``gr.Blocks`` context — confirm against the caller.

    Raises:
        RuntimeError: if the records could not be loaded.
    """
    repo_id = 'ZeroCommand/test-giskard-report'
    records = get_records_from_dataset_repo(repo_id)
    if records is None:
        # The loader returns None on failure; stop with a clear message
        # instead of failing obscurely further down.
        raise RuntimeError(f"Could not load leaderboard records from {repo_id!r}")

    model_ids = get_model_ids(records)
    dataset_ids = get_dataset_ids(records)
    column_names = records.columns.tolist()
    default_columns = ['model_id', 'dataset_id', 'total_issues', 'report_link']
    # NOTE(review): the initial table is not filtered by the default task
    # selection ('text_classification'); it only becomes filtered after the
    # first selector change.
    default_df = records[default_columns]  # extract columns selected
    types = get_types(default_df)
    display_df = get_display_df(default_df)  # the styled dataframe to display

    with gr.Row():
        task_select = gr.Dropdown(label='Task', choices=['text_classification', 'tabular'], value='text_classification', interactive=True)
        model_select = gr.Dropdown(label='Model id', choices=model_ids, interactive=True)
        dataset_select = gr.Dropdown(label='Dataset id', choices=dataset_ids, interactive=True)

    with gr.Row():
        columns_select = gr.CheckboxGroup(label='Show columns', choices=column_names, value=default_columns, interactive=True)

    with gr.Row():
        leaderboard_df = gr.DataFrame(display_df, datatype=types, interactive=False)

    @gr.on(triggers=[model_select.change, dataset_select.change, columns_select.change, task_select.change],
           inputs=[model_select, dataset_select, columns_select, task_select],
           outputs=[leaderboard_df])
    def filter_table(model_id, dataset_id, columns, task):
        """Re-render the leaderboard for the current selector values."""
        df = _filter_records(records, task, model_id, dataset_id, columns)
        return gr.update(value=get_display_df(df), datatype=get_types(df), interactive=False)