File size: 4,956 Bytes
13e8963 0eec09c 13e8963 e698b42 dfae691 13e8963 c28665f 13e8963 c28665f 13e8963 c28665f 13e8963 bb256f3 13e8963 1b11ded 13e8963 1b11ded 0eec09c bb256f3 0eec09c 13e8963 c28665f 13e8963 ca2e2c2 13e8963 c28665f 13e8963 c28665f 13e8963 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import gradio as gr # type: ignore
import plotly.express as px # type: ignore
from backend.data import load_cot_data, is_visible_model
from backend.envs import API, REPO_ID, TOKEN
# Image assets for the two logos rendered underneath the page title.
logo1_url = "https://raw.githubusercontent.com/logikon-ai/cot-eval/main/assets/AI2_Logo_Square.png"
logo2_url = "https://raw.githubusercontent.com/logikon-ai/cot-eval/main/assets/logo_logikon_notext_withborder.png"

# Centered flex row with both logos; each image links to its organisation.
LOGOS = (
    '<div style="display: flex; justify-content: center;">'
    f'<a href="https://allenai.org/"><img src="{logo1_url}" alt="AI2" '
    'style="width: 30vw; min-width: 20px; max-width: 60px;"></a>'
    ' '
    f'<a href="https://logikon.ai"><img src="{logo2_url}" alt="Logikon AI" '
    'style="width: 30vw; min-width: 20px; max-width: 60px; margin-left: 10px;"></a>'
    '</div>'
)

# Page heading followed by the logo row.
TITLE = '<h1 align="center" id="space-title"> Open CoT Dashboard</h1> ' + LOGOS

# Markdown blurb displayed below the title.
INTRODUCTION_TEXT = (
    "\n"
    "Baseline accuracies and marginal accuracy gains for specific models and CoT regimes "
    "from the [Open CoT Leaderboard](https://huggingface.co/spaces/logikon/open_cot_leaderboard).\n"
)
def restart_space():
    """Restart the space identified by ``REPO_ID`` through ``API.restart_space``."""
    restart_kwargs = {"repo_id": REPO_ID, "token": TOKEN}
    API.restart_space(**restart_kwargs)
# Load both evaluation frames once at import time; the plotting and table
# functions in this module read them as module-level globals.
try:
    df_cot_err, df_cot_regimes = load_cot_data()
except Exception as load_error:
    import time

    print(load_error)
    # Wait 10 seconds before requesting a restart of the space.
    time.sleep(10)
    restart_space()
def plot_evals_init(model_id, plotly_mode, request: gr.Request):
    """Build the plot, letting a ``model`` URL query parameter override ``model_id``.

    The override is applied only when the request carries a ``model`` query
    parameter whose value occurs in the ``model`` column of ``df_cot_err``;
    otherwise the ``model_id`` passed in is used unchanged.

    Returns:
        Whatever ``plot_evals`` returns for the resolved model id.
    """
    if request and "model" in request.query_params:
        requested_model = request.query_params["model"]
        known_models = df_cot_err.model.to_list()
        model_id = requested_model if requested_model in known_models else model_id
    return plot_evals(model_id, plotly_mode)
def plot_evals(model_id, plotly_mode):
    """Build the per-task scatter plot of base accuracy vs. marginal accuracy gain.

    Points belonging to ``model_id`` are coloured orange ("selected"), all
    other points gray. Traces of models for which ``is_visible_model`` is
    false (and which are not the selected model) start out as legend-only, so
    they can still be toggled on from the legend.

    Args:
        model_id: Value of the ``model`` column to highlight.
        plotly_mode: ``"dark"`` selects the ``plotly_dark`` template; any other
            value selects the default ``plotly`` template.

    Returns:
        A ``(figure, model_id)`` tuple; ``model_id`` is passed through unchanged.
    """
    df = df_cot_err.copy()
    is_selected = df.model.eq(model_id)
    df["selected"] = is_selected.map({True: "selected", False: "-"})
    df["visibility"] = df.model.apply(is_visible_model) | is_selected
    # NOTE: sorting df by ["selected", "model"] currently has no effect on
    # the trace order produced by px.scatter, so no sort is applied here.

    template = "plotly_dark" if plotly_mode == "dark" else "plotly"
    fig = px.scatter(
        df,
        x="base accuracy",
        y="marginal acc. gain",
        color="selected",
        symbol="model",
        facet_col="task",
        facet_col_wrap=3,
        category_orders={"selected": ["selected", "-"]},
        color_discrete_sequence=["Orange", "Gray"],
        template=template,
        error_y="acc_gain-err",
        hover_data=["model", "cot accuracy"],
        custom_data=["visibility"],
        width=1200,
        height=700,
    )

    def _hide_if_not_visible(trace):
        # A dict selector only compares trace *properties*, and "visibility"
        # is not one, so `selector=dict(visibility=False)` never matched.
        # Instead read the flag from customdata: Plotly Express places the
        # `custom_data` columns ahead of the `hover_data` columns, so column 0
        # is "visibility". Each trace is a single (selected, model, task)
        # combination, hence all of its rows carry the same flag.
        rows = trace.customdata
        if rows is None or len(rows) == 0:
            return
        if not any(bool(row[0]) for row in rows):
            trace.visible = "legendonly"

    fig.for_each_trace(_hide_if_not_visible)

    fig.update_layout(
        title={"automargin": True},
    )
    return fig, model_id
def styled_model_table_init(model_id, request: gr.Request):
    """Build the styled table, letting a ``model`` URL query parameter override ``model_id``.

    The override is applied only when the request carries a ``model`` query
    parameter whose value occurs in the ``model`` column of ``df_cot_regimes``;
    otherwise the ``model_id`` passed in is used unchanged.

    Returns:
        Whatever ``styled_model_table`` returns for the resolved model id.
    """
    if request and "model" in request.query_params:
        requested_model = request.query_params["model"]
        known_models = df_cot_regimes.model.to_list()
        model_id = requested_model if requested_model in known_models else model_id
    return styled_model_table(model_id)
def styled_model_table(model_id):
    """Return a styled table of all CoT regimes evaluated for ``model_id``.

    Rows of ``df_cot_regimes`` whose ``model`` equals ``model_id`` are kept,
    reduced to the task / regime / sampling / accuracy columns, and sorted by
    task and CoT chain. ``temperature`` is shown as ``temp`` and the chain
    name ``ReflectBeforeRun`` is shortened to ``Reflect``.

    Args:
        model_id: Value of the ``model`` column to filter on.

    Returns:
        The pandas ``Styler`` of the filtered frame with formatting applied.
    """

    def make_pretty(styler):
        """Apply number formatting, colour gradients and a bold task column."""
        styler.hide(axis="index")
        styler.format(precision=1)
        # axis=None: one shared colour scale across both accuracy columns.
        styler.background_gradient(
            axis=None,
            subset=["acc_base", "acc_cot"],
            vmin=20, vmax=100, cmap="YlGnBu",
        )
        # Diverging colour map, symmetric around a gain of zero.
        styler.background_gradient(
            axis=None,
            subset=["acc_gain"],
            vmin=-20, vmax=20, cmap="coolwarm",
        )
        styler.set_table_styles(
            {"task": [{"selector": "", "props": [("font-weight", "bold")]}]},
            overwrite=False,
        )
        return styler

    shown_columns = [
        "task", "cot_chain", "best_of",
        "temperature", "top_k", "top_p",
        "acc_base", "acc_cot", "acc_gain",
    ]
    df_cot_model = (
        df_cot_regimes.loc[df_cot_regimes.model.eq(model_id), shown_columns]
        .rename(columns={"temperature": "temp"})
        # Dict form: replace the value only within the `cot_chain` column.
        .replace({"cot_chain": "ReflectBeforeRun"}, "Reflect")
        .sort_values(["task", "cot_chain"])
        .reset_index(drop=True)
    )
    return df_cot_model.style.pipe(make_pretty)
# UI layout and event wiring.
# NOTE(review): block nesting was reconstructed; the three controls carrying
# `scale=` are assumed to share one row, with table and plot below - confirm.
demo = gr.Blocks()
with demo:
    gr.HTML(TITLE)
    gr.Markdown(INTRODUCTION_TEXT)
    with gr.Row():
        model_list = gr.Dropdown(
            list(df_cot_err.model.unique()),
            value="allenai/tulu-2-70b",
            label="Model",
            scale=2,
        )
        plotly_mode = gr.Radio(["dark", "light"], value="light", label="Plot theme", scale=1)
        submit = gr.Button("Update", scale=1)
    table = gr.DataFrame()
    plot = gr.Plot(label="evals")

    # "Update" redraws both the plot and the table for the chosen model.
    submit.click(plot_evals, [model_list, plotly_mode], [plot, model_list])
    submit.click(styled_model_table, model_list, table)

    # On page load use the *_init variants, which take the request so a
    # `model` URL query parameter can preselect the model.
    demo.load(plot_evals_init, [model_list, plotly_mode], [plot, model_list])
    demo.load(styled_model_table_init, model_list, table)

demo.launch()