Spaces:
Sleeping
Sleeping
import gradio as gr | |
import pandas as pd | |
import random | |
import plotly.express as px | |
from huggingface_hub import snapshot_download | |
import os | |
import logging | |
from config import ( | |
SETUPS, | |
LOCAL_RESULTS_DIR, | |
CITATION_BUTTON_TEXT, | |
CITATION_BUTTON_LABEL, | |
) | |
from parsing import read_all_configs, get_common_langs | |
# Configure root logging for the app: INFO level, timestamped records, and a
# single stream handler (stderr by default).
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(
    level=logging.INFO,
    format=_LOG_FORMAT,
    handlers=[logging.StreamHandler()],
)
logger = logging.getLogger(__name__)
# Fetch the pre-computed results dataset into LOCAL_RESULTS_DIR at import
# time. Per-sample and transcript files are skipped via ignore_patterns;
# TOKEN (if set) authenticates the download.
logger.info("Saving results locally at: %s", LOCAL_RESULTS_DIR)
try:
    snapshot_download(
        repo_id="g8a9/fair-asr-results",
        local_dir=LOCAL_RESULTS_DIR,
        repo_type="dataset",
        tqdm_class=None,
        etag_timeout=30,
        ignore_patterns=["*samples*", "*transcripts*"],
        token=os.environ.get("TOKEN"),
    )
except Exception:
    # The leaderboard cannot render without the results, so fail fast, but
    # log the cause first. A bare `raise` keeps the original traceback intact.
    logger.exception("Failed to download results from g8a9/fair-asr-results")
    raise
def format_dataframe(df, times_100=False): | |
if times_100: | |
df = df.map(lambda x: (f"{x * 100:.3f}%" if isinstance(x, (int, float)) else x)) | |
else: | |
df = df.map(lambda x: (f"{x:.4f}" if isinstance(x, (int, float)) else x)) | |
return df | |
def _build_models_with_nan_md(models_with_nan): | |
model_markups = [f"*{m}*" for m in models_with_nan] | |
return f""" | |
We are currently hiding the results of {', '.join(model_markups)} because they don't support all languages. | |
""" | |
def build_components(show_common_langs):
    """Rebuild the F-M tab outputs for the current checkbox selection.

    Used as a Gradio event handler. It returns, in order:
    - the aggregated table;
    - the per-language table (values shown as percentages);
    - the bar plot;
    - the hidden-models notice, visible only when at least one model was
      dropped.
    """
    agg, per_lang, fig, hidden = _populate_components(show_common_langs)
    notice_text = _build_models_with_nan_md(hidden)
    notice_visible = len(hidden) > 0

    agg_table = gr.DataFrame(format_dataframe(agg))
    lang_table = gr.DataFrame(format_dataframe(per_lang, times_100=True))
    plot = gr.Plot(fig)
    notice = gr.Markdown(notice_text, visible=notice_visible)
    return agg_table, lang_table, plot, notice
def _populate_components(show_common_langs):
    """Load the F-M results and derive the data shown on the leaderboard.

    Reads the results for the first configured setup (``SETUPS[0]``). If
    requested, it keeps only the languages returned by ``get_common_langs()``.
    It then drops every model with a missing (NaN) value in any remaining row.

    Args:
        show_common_langs: Truthy to keep only the common languages. The
            Gradio checkbox group passes a list, so any non-empty selection
            enables the filter.

    Returns:
        A 4-tuple ``(aggregated_df, lang_df, barplot_fig, models_with_nan)``:

        - ``aggregated_df``: one row per model, with ``Gap`` equal to 100 times
          the sum of absolute gaps, sorted ascending (best model first).
        - ``lang_df``: one row per model and one column per language, holding
          the pivot-table (mean) gap as a raw fraction.
        - ``barplot_fig``: a grouped plotly bar chart of per-language gaps in
          percent for the top-3 models.
        - ``models_with_nan``: the names of the hidden models.
    """
    fm = SETUPS[0]
    setup = fm["majority_group"] + "_" + fm["minority_group"]
    results = read_all_configs(setup)

    if show_common_langs:
        common_langs = get_common_langs()
        logger.info("Common langs: %s", common_langs)
        results = results[results["Language"].isin(common_langs)]

    results, models_with_nan = _drop_models_with_nan(results)

    # Lower is better: sum of |gap| across languages, expressed in percent.
    aggregated_df = (
        results.pivot_table(
            index="Model", values="Gap", aggfunc=lambda x: 100 * x.abs().sum()
        )
        .reset_index()
        .sort_values("Gap")
    )
    top_3_models = aggregated_df["Model"].head(3).tolist()

    lang_df = results.pivot_table(
        index="Model",
        values="Gap",
        columns="Language",
    ).reset_index()

    barplot_fig = _build_gap_barplot(results, lang_df, top_3_models)
    return aggregated_df, lang_df, barplot_fig, models_with_nan


def _drop_models_with_nan(results):
    """Drop every model that has a NaN in any of its rows.

    Logs the affected languages for each such model.

    Returns:
        ``(filtered_results, models_with_nan)``, where ``models_with_nan`` is
        in order of first appearance.
    """
    # Compute the NaN mask once and reuse it for logging and filtering.
    nan_rows = results[results.isna().any(axis=1)]
    missing_langs = nan_rows.groupby("Model")["Language"].apply(list).to_dict()
    for model, langs in missing_langs.items():
        # str() guards against NaN entries in the Language column itself.
        logger.info(
            "Model %s is missing results for languages: %s",
            model,
            ", ".join(map(str, langs)),
        )
    models_with_nan = nan_rows["Model"].unique().tolist()
    logger.info("Models with NaN values: %s", models_with_nan)
    return results[~results["Model"].isin(models_with_nan)], models_with_nan


def _build_gap_barplot(results, lang_df, top_models):
    """Build a grouped bar chart of per-language gaps (in %) for ``top_models``.

    The x-axis orders languages by the gap of ``top_models[0]`` (the best
    model), largest first. The ordering is skipped if ``top_models`` is empty.
    """
    # Scale a new frame rather than assigning into ``results``. ``results`` may
    # be a filtered slice, and in-place assignment there triggers pandas'
    # SettingWithCopyWarning.
    plot_df = results.loc[results["Model"].isin(top_models)].assign(
        Gap=lambda df: df["Gap"] * 100
    )
    fig = px.bar(
        plot_df,
        x="Language",
        y="Gap",
        color="Model",
        title="Gaps by Language and Model (top 3, sorted by the best model)",
        labels={
            # The bars show signed per-language gaps, not a sum of absolutes.
            "Gap": "Gap (%)",
            "Language": "Language",
            "Model": "Model",
        },
        barmode="group",
    )
    if top_models:
        best_model = top_models[0]
        lang_order = (
            lang_df.set_index("Model")
            .loc[best_model]
            .sort_values(ascending=False)
            .index
        )
        logger.info("Lang order: %s", lang_order)
        fig.update_layout(
            xaxis={"categoryorder": "array", "categoryarray": list(lang_order)}
        )
    return fig
with gr.Blocks() as fm_interface:
    # Initial render uses all languages (show_common_langs=False). The
    # checkbox handler in the main interface later replaces these four
    # components with the output of build_components.
    aggregated_df, lang_df, barplot_fig, model_with_nan = _populate_components(
        show_common_langs=False
    )
    # Show the notice only when some model was hidden, matching the
    # visibility rule that build_components applies on updates.
    model_with_nans_md = gr.Markdown(
        _build_models_with_nan_md(model_with_nan),
        visible=len(model_with_nan) > 0,
    )
    gr.Markdown("### Sum of Absolute Gaps ⬇️")
    aggregated_df_comp = gr.DataFrame(format_dataframe(aggregated_df))
    gr.Markdown("#### F-M gaps by language")
    lang_df_comp = gr.DataFrame(format_dataframe(lang_df, times_100=True))
    barplot_fig_comp = gr.Plot(barplot_fig)
###################
# LIST MAIN TABS
###################
# Tab interfaces and their titles, matched by position.
titles = ["F-M Setup"]
tabs = [fm_interface]

# Full-width header image rendered at the top of the main interface.
banner = """
<style>
.full-width-image {
width: 100%;
height: auto;
margin: 0;
padding: 0;
}
</style>
<div>
<img src="https://huggingface.co/spaces/g8a9/fair-asr-leaderboard/raw/main/twists_banner.png" alt="Twists Banner" class="full-width-image">
</div>
"""
###################
# MAIN INTERFACE
###################
with gr.Blocks() as demo:
    gr.HTML(banner)

    with gr.Row() as config_row:
        show_common_langs = gr.CheckboxGroup(
            choices=["Show only common languages"],
            label="Main configuration",
        )
        # Only one dataset exists for now. It is shown pre-selected and locked.
        include_datasets = gr.CheckboxGroup(
            choices=["Mozilla CV 17"],
            value=["Mozilla CV 17"],
            label="Include datasets",
            interactive=False,
        )

        # F-M tab components regenerated by build_components, in the same
        # order as its return tuple.
        fm_outputs = [
            aggregated_df_comp,
            lang_df_comp,
            barplot_fig_comp,
            model_with_nans_md,
        ]
        show_common_langs.input(
            build_components,
            inputs=[show_common_langs],
            outputs=fm_outputs,
        )

    gr.TabbedInterface(tabs, titles)

    gr.Textbox(
        value=CITATION_BUTTON_TEXT,
        label=CITATION_BUTTON_LABEL,
        max_lines=6,
        show_copy_button=True,
    )

if __name__ == "__main__":
    demo.launch()