|
import altair as alt |
|
import gradio as gr |
|
import pandas as pd |
|
|
|
from functools import partial |
|
from datasets import load_dataset |
|
|
|
def get_data():
    """Load the tag-tracking dataset and derive the two frames used for plotting.

    The ``train`` split of ``ybelkada/model_cards_correct_tag`` is fetched with
    ``datasets.load_dataset``, ``commit_dates`` is parsed to datetimes and the
    rows are sorted chronologically.

    Returns:
        tuple: ``(ratio_df, melted_df)`` where

        * ``ratio_df`` holds ``commit_dates`` and ``ratio``, computed as
          ``(1 - missing_library_name / total_transformers_model) * 100``.
          Rows whose total is zero get ``NaN`` rather than an infinite ratio.
        * ``melted_df`` is the long-format frame with columns
          ``commit_dates``, ``type`` (either ``total_transformers_model`` or
          ``missing_library_name``) and ``value``.
    """
    dataset_id = "ybelkada/model_cards_correct_tag"

    # ``to_pandas()`` already yields a DataFrame, so no extra wrapping is needed.
    df = load_dataset(dataset_id, split="train").to_pandas()
    df["commit_dates"] = pd.to_datetime(df["commit_dates"])
    df = df.sort_values(by="commit_dates")

    melted_df = pd.melt(
        df,
        id_vars=["commit_dates"],
        value_vars=["total_transformers_model", "missing_library_name"],
        var_name="type",
    )

    # Mask zero totals so the division yields NaN instead of +/-inf; an
    # infinite ratio would otherwise leak into min()/max() of the column.
    totals = df["total_transformers_model"]
    nonzero_totals = totals.where(totals != 0)
    df["ratio"] = (1 - df["missing_library_name"] / nonzero_totals) * 100

    ratio_df = df[["commit_dates", "ratio"]].copy()
    return ratio_df, melted_df
|
|
|
# Module-level cache of the plotting frames, populated once at import time.
# make_plot() rebinds both names when it is called with refresh=True.
ratio_df, melted_df = get_data()
|
|
|
def _highlighted_time_series(frame, y_field, y_title, highlight_field, **extra_encodings):
    """Build a layered points-and-lines chart of ``y_field`` over ``commit_dates``.

    Args:
        frame: DataFrame containing ``commit_dates`` and ``y_field`` columns.
        y_field: Name of the quantitative column plotted on the y axis.
        y_title: Axis title shown for the y axis.
        highlight_field: Field the mouse-over selection is projected onto.
        **extra_encodings: Additional channel encodings (e.g. ``color``)
            applied to both the point and the line layers.

    Returns:
        The layered Altair chart ``points + lines``.
    """
    # NOTE(review): ``alt.selection(type='single')`` and ``add_selection`` are
    # the Altair 4 spellings; newer releases appear to prefer
    # ``selection_point`` / ``add_params`` - confirm the pinned version before
    # migrating.
    highlight = alt.selection(
        type="single",
        on="mouseover",
        fields=[highlight_field],
        nearest=True,
    )

    # Fit the y axis to the observed range of the plotted column.
    y_values = frame[y_field]
    y_scale = alt.Scale(domain=(y_values.min(), y_values.max()))

    base = alt.Chart(frame).encode(
        x=alt.X("commit_dates:T", title="Date"),
        y=alt.Y(f"{y_field}:Q", scale=y_scale, title=y_title),
        **extra_encodings,
    )

    # The point layer carries the selection and fixes the overall chart size.
    points = (
        base.mark_circle()
        .encode(opacity=alt.value(1))
        .add_selection(highlight)
        .properties(width=1200, height=800)
    )

    # Lines are drawn thicker (3 instead of 1) when they match the selection.
    lines = base.mark_line().encode(
        size=alt.condition(~highlight, alt.value(1), alt.value(3))
    )

    return points + lines


def make_plot(plot_type, refresh=False):
    """Return the Altair chart for the requested plot type.

    Args:
        plot_type: Plot type label. The exact string
            ``"Total models with missing 'transformers' tag"`` selects the
            count chart built from ``melted_df``; any other value selects the
            ratio chart built from ``ratio_df``.
        refresh: When true, call ``get_data()`` again and rebind the
            module-level ``ratio_df`` / ``melted_df`` before plotting.

    Returns:
        A layered Altair chart (points plus lines).
    """
    global ratio_df, melted_df

    if refresh:
        ratio_df, melted_df = get_data()

    if plot_type == "Total models with missing 'transformers' tag":
        return _highlighted_time_series(
            melted_df,
            y_field="value",
            y_title="Count",
            highlight_field="type",
            color="type:N",
        )

    return _highlighted_time_series(
        ratio_df,
        y_field="ratio",
        y_title="(1 - missing_library_name / total_transformers_model) * 100 - Higher is better",
        highlight_field="ratio",
    )
|
|
|
|
|
# Gradio UI: a plot-type selector, a refresh button and the plot itself.
with gr.Blocks() as demo:
    _plot_choices = [
        "Total models with missing 'transformers' tag",
        "Proportion of models correctly tagged with 'transformers' tag",
    ]

    button = gr.Radio(
        label="Plot type",
        choices=_plot_choices,
        value=_plot_choices[0],
    )
    refresh_button = gr.Button(value="Fetch latest data")

    plot = gr.Plot(label="Plot")

    # Same handler as a plain redraw, but forces get_data() to run again first.
    _refresh_and_plot = partial(make_plot, refresh=True)

    # Redraw when the selection changes, on an explicit refresh request,
    # and once when the page loads.
    button.change(make_plot, inputs=[button], outputs=[plot])
    refresh_button.click(_refresh_and_plot, inputs=[button], outputs=[plot])
    demo.load(make_plot, inputs=[button], outputs=[plot])
|
|
|
# Launch the Gradio app only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    demo.launch()