import gradio as gr
from typing import List, Tuple
import plotly.express as px
from huggingface_hub import snapshot_download
import os
import pdb
import logging
import pandas as pd
from config import LOCAL_RESULTS_DIR, CITATION_BUTTON_TEXT, DatasetHelper, ModelHelper
from parsing import read_all_configs

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[
        # logging.FileHandler("app.log"),
        logging.StreamHandler()
    ],
)
logger = logging.getLogger(__name__)

# Fetch the result files once at startup; sample and transcript files are
# excluded by pattern. A download failure is logged and aborts the app.
logger.info("Saving results locally at: %s", LOCAL_RESULTS_DIR)
try:
    snapshot_download(
        repo_id="g8a9/fair-asr-results",
        local_dir=LOCAL_RESULTS_DIR,
        repo_type="dataset",
        tqdm_class=None,
        etag_timeout=30,
        ignore_patterns=["*samples*", "*transcripts*"],
        token=os.environ.get("TOKEN"),
    )
except Exception:
    logger.exception("Failed to download results from g8a9/fair-asr-results")
    raise


def format_dataframe(df, times_100=False):
    """Return a copy of *df* with every int/float cell rendered as a string.

    Args:
        df: Table to format. Non-numeric cells (e.g. model names) are kept as-is.
        times_100: If True, numbers are multiplied by 100 and shown as
            percentages with three decimals; otherwise they are shown with
            four decimals.

    Returns:
        A new DataFrame (uses ``DataFrame.map``, available in pandas >= 2.1).
    """
    if times_100:
        return df.map(lambda x: f"{x * 100:.3f}%" if isinstance(x, (int, float)) else x)
    return df.map(lambda x: f"{x:.4f}" if isinstance(x, (int, float)) else x)


def _build_models_with_nan_md(models_with_nan):
    """Build the Markdown notice listing models hidden for incomplete results.

    Args:
        models_with_nan: Names of the models that were filtered out.

    Returns:
        The notice text, or an empty string when no model is hidden.
    """
    if not models_with_nan:
        return ""
    model_markups = ", ".join(f"*{m}*" for m in models_with_nan)
    return (
        f"We are currently hiding the results of {model_markups} "
        "because they don't support all languages."
    )


def build_components(show_common_langs, selected_datasets: List[str]):
    """Recompute every leaderboard component for the current UI configuration.

    Used as the event handler of the "Main configuration" checkbox group.

    Args:
        show_common_langs: Checkbox selection; a non-empty (truthy) value
            restricts results to the languages common to all models.
        selected_datasets: Dataset names selected in the UI (forwarded to
            ``_populate_components``, which does not currently use them).

    Returns:
        A tuple of updated Gradio components, in the order expected by the
        ``outputs`` list of the event listener: aggregated table, the two
        per-language tables, the two bar plots and the hidden-models notice.
    """
    aggregated_df, lang_dfs, barplot_figs, models_with_nan = _populate_components(
        show_common_langs, selected_datasets
    )
    models_with_nan_md = _build_models_with_nan_md(models_with_nan)

    return (
        gr.DataFrame(format_dataframe(aggregated_df)),
        gr.DataFrame(format_dataframe(lang_dfs[0], times_100=True)),
        gr.DataFrame(format_dataframe(lang_dfs[1], times_100=True)),
        gr.Plot(barplot_figs[0]),
        gr.Plot(barplot_figs[1]),
        gr.Markdown(models_with_nan_md, visible=len(models_with_nan) > 0),
    )


def _populate_components(
    show_common_langs: bool,
    selected_datasets: List[str],
    contrast_type: str = "F-M",
) -> Tuple[pd.DataFrame, List[pd.DataFrame], List["plotly.graph_objects.Figure"], List[str]]:
    """Compute the leaderboard tables and plots for one contrast.

    Args:
        show_common_langs: If truthy, keep only the languages returned by
            ``model_h.get_common_langs()``.
        selected_datasets: Currently not used to filter the results.
        contrast_type: Contrast passed to ``read_all_configs`` (default "F-M").

    Returns:
        - aggregated table: one row per model with a ``Gap (<type>)`` column
          per ``Type`` value (100 * sum of absolute gaps), plus their mean in
          ``Avg``; sorted by ``Avg`` ascending.
        - one per-language pivot of the raw ``Gap`` per ``Type`` value.
        - one grouped bar plot per ``Type`` value for the three models with
          the smallest aggregated gap.
        - names of the models dropped because some of their rows contain NaN.
    """
    results = read_all_configs(contrast_type)

    if show_common_langs:
        common_langs = model_h.get_common_langs()
        logger.info("Common langs: %s", common_langs)
        results = results[results["Language"].isin(common_langs)]

    # Drop every model with at least one NaN row, so the remaining models are
    # compared on the same rows.
    nan_rows = results[results.isna().any(axis=1)]
    missing_langs = nan_rows.groupby("Model")["Language"].apply(list).to_dict()
    for model, langs in missing_langs.items():
        logger.info(
            "Model %s is missing results for languages: %s", model, ", ".join(langs)
        )
    models_with_nan = nan_rows["Model"].unique().tolist()
    logger.info("Models with NaN values: %s", models_with_nan)
    results = results[~results["Model"].isin(models_with_nan)]

    type_dfs = []
    lang_dfs = []
    barplot_figs = []
    for type_name, type_df in results.groupby("Type"):
        gap_col = f"Gap ({type_name})"

        # Sum of absolute gaps per model, in percent (lower is better).
        aggregated_df = type_df.pivot_table(
            index="Model",
            values="Gap",
            aggfunc=lambda x: 100 * x.abs().sum(),
        ).rename(columns={"Gap": gap_col})
        type_dfs.append(aggregated_df)

        # pivot_table orders rows by model name, so rank by gap explicitly
        # before choosing the best / top-3 models.
        ranked_models = aggregated_df[gap_col].sort_values().index
        best_model = ranked_models[0]
        top_3_models = ranked_models[:3].tolist()

        # Raw (unscaled) gap per model and language.
        lang_df = type_df.pivot_table(
            index="Model",
            values="Gap",
            columns="Language",
        ).reset_index()
        lang_dfs.append(lang_df)

        # Scale a filtered copy for plotting instead of mutating the group.
        plot_df = type_df[type_df["Model"].isin(top_3_models)].assign(
            Gap=lambda d: d["Gap"] * 100
        )
        barplot_fig = px.bar(
            plot_df,
            x="Language",
            y="Gap",
            color="Model",
            title=f"{type_name}: Gaps by Language and Model (top 3, sorted by the best model)",
            labels={
                "Gap": f"{contrast_type} Gap (%)",
                "Language": "Language",
                "Model": "Model",
            },
            barmode="group",
        )

        # Order the x axis by the best model's gap, largest first.
        lang_order = (
            lang_df.set_index("Model")
            .loc[best_model]
            .sort_values(ascending=False)
            .index.tolist()
        )
        logger.info("Lang order: %s", lang_order)
        barplot_fig.update_layout(
            xaxis={"categoryorder": "array", "categoryarray": lang_order}
        )
        barplot_figs.append(barplot_fig)

    aggregated_df = pd.concat(type_dfs, axis=1, join="inner")
    aggregated_df["Avg"] = aggregated_df.mean(axis=1)
    aggregated_df = aggregated_df.sort_values("Avg").reset_index()

    return aggregated_df, lang_dfs, barplot_figs, models_with_nan


dataset_h = DatasetHelper()
model_h = ModelHelper()

with gr.Blocks() as fm_interface:
    aggregated_df, lang_dfs, barplot_figs, model_with_nan = _populate_components(
        show_common_langs=False, selected_datasets=dataset_h.get_dataset_names()
    )

    # Hide the notice when no model was filtered out (same rule as
    # build_components uses on updates).
    model_with_nans_md = gr.Markdown(
        _build_models_with_nan_md(model_with_nan),
        visible=len(model_with_nan) > 0,
    )

    gr.Markdown("### Sum of Absolute Gaps ⬇️")
    aggregated_df_comp = gr.DataFrame(format_dataframe(aggregated_df))

    gr.Markdown("#### Read: gaps by language")
    lang_df_comp_0 = gr.DataFrame(format_dataframe(lang_dfs[0], times_100=True))
    barplot_fig_comp_0 = gr.Plot(barplot_figs[0])

    gr.Markdown("#### Spontaneous: gaps by language")
    lang_df_comp_1 = gr.DataFrame(format_dataframe(lang_dfs[1], times_100=True))
    barplot_fig_comp_1 = gr.Plot(barplot_figs[1])


###################
# LIST MAIN TABS
###################
tabs = [fm_interface]
titles = ["F-M Setup"]

banner = """
Twists Banner
"""

###################
# MAIN INTERFACE
###################
with gr.Blocks() as demo:
    gr.HTML(banner)

    with gr.Row() as config_row:
        show_common_langs = gr.CheckboxGroup(
            choices=["Show only common languages"],
            label="Main configuration",
        )
        datasets_names = dataset_h.get_dataset_names()
        include_datasets = gr.CheckboxGroup(
            choices=datasets_names,
            label="Include datasets",
            value=datasets_names,
            interactive=False,
        )

    # Rebuild every component of the F-M tab whenever the user toggles the
    # common-languages option.
    show_common_langs.input(
        build_components,
        inputs=[show_common_langs, include_datasets],
        outputs=[
            aggregated_df_comp,
            lang_df_comp_0,
            lang_df_comp_1,
            barplot_fig_comp_0,
            barplot_fig_comp_1,
            model_with_nans_md,
        ],
    )

    gr.TabbedInterface(tabs, titles)

    gr.Markdown(
        "### Citation\n"
        "If you find these results useful, please cite the following paper:"
    )
    gr.Markdown(f"```\n{CITATION_BUTTON_TEXT}\n```")


if __name__ == "__main__":
    demo.launch()