"""Gradio leaderboard tab.

Loads records from the dataset repo named by ``leaderboard.LEADERBOARD`` and
shows them in a table that can be filtered by task, model id, dataset id and
visible columns.
"""
import datetime
import logging

import datasets
import gradio as gr
import pandas as pd

from fetch_utils import (
    check_dataset_and_get_config,
    check_dataset_and_get_split,
)
import leaderboard

logger = logging.getLogger(__name__)

# Time of the last (attempted) load of ``leaderboard.records``; used to
# throttle reloads triggered by selecting the leaderboard tab. Starts at the
# Unix epoch so the first staleness check always passes.
update_time = datetime.datetime.fromtimestamp(0)

# Minimum interval between two reloads of the leaderboard dataset.
REFRESH_INTERVAL = datetime.timedelta(minutes=10)


def get_records_from_dataset_repo(dataset_id):
    """Load the first split of the first config of a dataset as a DataFrame.

    Args:
        dataset_id: Repository id passed to ``datasets.load_dataset``.

    Returns:
        A ``pandas.DataFrame`` with the dataset rows, or an empty DataFrame
        if the config/split lookup or the load itself raises.
    """
    dataset_config = None
    try:
        # The config/split helpers can raise too, so they are inside the try:
        # every failure degrades to the empty-DataFrame fallback below.
        dataset_config = check_dataset_and_get_config(dataset_id)
        logger.info(f"Dataset {dataset_id} has configs {dataset_config}")
        dataset_split = check_dataset_and_get_split(dataset_id, dataset_config[0])
        logger.info(f"Dataset {dataset_id} has splits {dataset_split}")
        ds = datasets.load_dataset(dataset_id, dataset_config[0])[dataset_split[0]]
        return ds.to_pandas()
    except Exception as e:
        logger.warning(
            f"Failed to load dataset {dataset_id} with config {dataset_config}: {e}"
        )
        return pd.DataFrame()


def _get_unique_ids(ds, column):
    """Return ``["Any"]`` followed by the sorted distinct non-null values of ``ds[column]``.

    A DataFrame without ``column`` (e.g. the empty fallback of
    :func:`get_records_from_dataset_repo`) yields just ``["Any"]``.
    """
    if column not in ds.columns:
        logger.warning(f"Column {column} not found in leaderboard records")
        return ["Any"]
    # Sorted for a stable dropdown order: iteration order of a ``set`` of
    # strings changes between interpreter runs (hash randomization).
    ids = sorted(set(ds[column].dropna().tolist()), key=str)
    logger.info(f"Found {len(ids)} unique values in column {column}")
    return ["Any", *ids]


def get_model_ids(ds):
    """Return the dropdown choices for model ids: ``"Any"`` then each sorted id."""
    return _get_unique_ids(ds, "model_id")


def get_dataset_ids(ds):
    """Return the dropdown choices for dataset ids: ``"Any"`` then each sorted id."""
    return _get_unique_ids(ds, "dataset_id")


def get_types(ds):
    """Map each column dtype of *ds* to a Gradio DataFrame ``datatype`` string.

    Object columns become ``"markdown"``; integer and floating-point columns
    (any width, signed or unsigned) become ``"number"``; any other dtype is
    passed through as its ``str()`` name (e.g. ``"bool"``).
    """
    types = []
    for dtype in ds.dtypes.to_list():
        if pd.api.types.is_object_dtype(dtype):
            types.append("markdown")
        elif pd.api.types.is_integer_dtype(dtype) or pd.api.types.is_float_dtype(dtype):
            types.append("number")
        else:
            types.append(str(dtype))
    return types


def get_display_df(df):
    """Return a styled copy of *df* for display; *df* itself is not modified.

    The ``model_id``, ``dataset_id`` and ``report_link`` columns, when
    present, are passed through their formatter below; other columns are
    copied unchanged.
    """
    # NOTE(review): all three formatters currently emit plain "🔗<value>"
    # text; kept per column so each can be given its own link target.
    formatters = {
        "model_id": lambda x: f'🔗{x}',
        "dataset_id": lambda x: f'🔗{x}',
        "report_link": lambda x: f'🔗{x}',
    }
    display_df = df.copy()
    for column, formatter in formatters.items():
        if column in display_df.columns:
            display_df[column] = display_df[column].apply(formatter)
    return display_df


def get_demo(leaderboard_tab):
    """Build the leaderboard widgets and wire their events.

    Loads ``leaderboard.records``, then creates the task/model/dataset
    dropdowns, the column selector and the results table. Selecting
    ``leaderboard_tab`` reloads the records at most once per
    ``REFRESH_INTERVAL``; changing any filter re-renders the table.

    Args:
        leaderboard_tab: Component whose ``.select`` event triggers a reload
            (presumably the ``gr.Tab`` hosting this view — confirm at caller).
    """
    global update_time
    update_time = datetime.datetime.now()
    logger.info("Loading leaderboard records")
    leaderboard.records = get_records_from_dataset_repo(leaderboard.LEADERBOARD)
    records = leaderboard.records

    model_ids = get_model_ids(records)
    dataset_ids = get_dataset_ids(records)

    column_names = records.columns.tolist()
    default_columns = ["model_id", "dataset_id", "total_issues", "report_link"]
    default_df = records[default_columns]  # extract columns selected
    types = get_types(default_df)
    display_df = get_display_df(default_df)  # the styled dataframe to display

    with gr.Row():
        task_select = gr.Dropdown(
            label="Task",
            choices=["text_classification", "tabular"],
            value="text_classification",
            interactive=True,
        )
        model_select = gr.Dropdown(
            label="Model id", choices=model_ids, value=model_ids[0], interactive=True
        )
        dataset_select = gr.Dropdown(
            label="Dataset id",
            choices=dataset_ids,
            value=dataset_ids[0],
            interactive=True,
        )

    with gr.Row():
        columns_select = gr.CheckboxGroup(
            label="Show columns",
            choices=column_names,
            value=default_columns,
            interactive=True,
        )

    with gr.Row():
        leaderboard_df = gr.DataFrame(display_df, datatype=types, interactive=False)

    def update_leaderboard_records(model_id, dataset_id, columns, task):
        """Reload the records if stale, then re-apply the current filters."""
        global update_time
        if datetime.datetime.now() - update_time < REFRESH_INTERVAL:
            return gr.update()
        update_time = datetime.datetime.now()
        logger.info("Updating leaderboard records")
        new_records = get_records_from_dataset_repo(leaderboard.LEADERBOARD)
        if new_records.empty:
            # A failed load returns an empty DataFrame with no "task" column;
            # installing it would make every filter_table call raise KeyError,
            # so keep serving the previous snapshot instead.
            logger.warning(
                "Leaderboard refresh returned no records; keeping previous data"
            )
            return gr.update()
        leaderboard.records = new_records
        return filter_table(model_id, dataset_id, columns, task)

    leaderboard_tab.select(
        fn=update_leaderboard_records,
        inputs=[model_select, dataset_select, columns_select, task_select],
        outputs=[leaderboard_df],
    )

    @gr.on(
        triggers=[
            model_select.change,
            dataset_select.change,
            columns_select.change,
            task_select.change,
        ],
        inputs=[model_select, dataset_select, columns_select, task_select],
        outputs=[leaderboard_df],
    )
    def filter_table(model_id, dataset_id, columns, task):
        """Filter ``leaderboard.records`` and return a table update.

        Rows are kept when they match ``task`` and, unless the value is empty
        or ``"Any"``, ``model_id`` and ``dataset_id``; only ``columns`` are shown.
        """
        logger.info("Filtering leaderboard records")
        records = leaderboard.records
        # filter the table based on task
        df = records[(records["task"] == task)]
        # filter the table based on the model_id and dataset_id
        if model_id and model_id != "Any":
            df = df[(df["model_id"] == model_id)]
        if dataset_id and dataset_id != "Any":
            df = df[(df["dataset_id"] == dataset_id)]

        # filter the table based on the columns
        df = df[columns]
        types = get_types(df)
        display_df = get_display_df(df)
        return gr.update(value=display_df, datatype=types, interactive=False)