import logging

import datasets
import gradio as gr

from fetch_utils import check_dataset_and_get_config, check_dataset_and_get_split


def get_records_from_dataset_repo(dataset_id):
    dataset_config = check_dataset_and_get_config(dataset_id)
    logging.info(f"Dataset {dataset_id} has configs {dataset_config}")
    dataset_split = check_dataset_and_get_split(dataset_id, dataset_config[0])
    logging.info(f"Dataset {dataset_id} has splits {dataset_split}")

    try:
        ds = datasets.load_dataset(dataset_id, dataset_config[0])[dataset_split[0]]
        df = ds.to_pandas()
        return df
    except Exception as e:
        logging.warning(
            f"Failed to load dataset {dataset_id} with config {dataset_config}: {e}"
        )
        return None


def get_model_ids(ds):
    logging.info(f"Model id column values: {ds['model_id']}")
    models = ds["model_id"].tolist()
    # keep only the unique model ids
    model_ids = list(set(models))
    return model_ids


def get_dataset_ids(ds):
    logging.info(f"Dataset id column values: {ds['dataset_id']}")
    # use a local name that does not shadow the imported `datasets` module
    dataset_list = ds["dataset_id"].tolist()
    dataset_ids = list(set(dataset_list))
    return dataset_ids


def get_types(ds):
    # map pandas dtypes to Gradio DataFrame datatypes for each column
    types = [str(t) for t in ds.dtypes.to_list()]
    types = [t.replace("object", "markdown") for t in types]
    types = [t.replace("float64", "number") for t in types]
    types = [t.replace("int64", "number") for t in types]
    return types


def get_display_df(df):
    # render the id and link columns as clickable links
    # (the Hub URLs below are the assumed link targets for model and dataset ids)
    display_df = df.copy()
    columns = display_df.columns.tolist()
    if "model_id" in columns:
        display_df["model_id"] = display_df["model_id"].apply(
            lambda x: f'<a href="https://huggingface.co/{x}" target="_blank">🔗{x}</a>'
        )
    if "dataset_id" in columns:
        display_df["dataset_id"] = display_df["dataset_id"].apply(
            lambda x: f'<a href="https://huggingface.co/datasets/{x}" target="_blank">🔗{x}</a>'
        )
    if "report_link" in columns:
        display_df["report_link"] = display_df["report_link"].apply(
            lambda x: f'<a href="{x}" target="_blank">🔗{x}</a>'
        )
    return display_df


def get_demo():
    records = get_records_from_dataset_repo("ZeroCommand/test-giskard-report")

    model_ids = get_model_ids(records)
    dataset_ids = get_dataset_ids(records)

    column_names = records.columns.tolist()
    # default columns to show
    default_columns = ["model_id", "dataset_id", "total_issues", "report_link"]
    default_df = records[default_columns]
    types = get_types(default_df)
    display_df = get_display_df(default_df)

    with gr.Row():
        task_select = gr.Dropdown(
            label="Task",
            choices=["text_classification", "tabular"],
            value="text_classification",
            interactive=True,
        )
        model_select = gr.Dropdown(label="Model id", choices=model_ids, interactive=True)
        dataset_select = gr.Dropdown(label="Dataset id", choices=dataset_ids, interactive=True)

    with gr.Row():
        columns_select = gr.CheckboxGroup(
            label="Show columns",
            choices=column_names,
            value=default_columns,
            interactive=True,
        )

    with gr.Row():
        leaderboard_df = gr.DataFrame(display_df, datatype=types, interactive=False)

    @gr.on(
        triggers=[
            model_select.change,
            dataset_select.change,
            columns_select.change,
            task_select.change,
        ],
        inputs=[model_select, dataset_select, columns_select, task_select],
        outputs=[leaderboard_df],
    )
    def filter_table(model_id, dataset_id, columns, task):
        # filter by task first, then narrow down by model id and dataset id
        df = records[records["task"] == task]
        if model_id:
            df = df[df["model_id"] == model_id]
        if dataset_id:
            df = df[df["dataset_id"] == dataset_id]
        # keep only the selected columns
        df = df[columns]
        types = get_types(df)
        display_df = get_display_df(df)
        return gr.update(value=display_df, datatype=types, interactive=False)
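

# Usage sketch (illustrative, not part of the original module): get_demo() only
# lays out rows and components, so it is assumed to be called from inside a
# gr.Blocks() context by the app's entry point (e.g. an app.py). A minimal
# standalone run could look like this:
if __name__ == "__main__":
    with gr.Blocks() as demo:
        get_demo()
    demo.launch()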