File size: 9,093 Bytes
507a14d 9ceb843 e5d5995 8e499f4 df04a09 ab74236 bbe05a0 507a14d df04a09 ab74236 7f5f365 df04a09 9ceb843 e5d5995 7f5f365 507a14d 9ceb843 df04a09 9ceb843 e4cd4cd 9ceb843 507a14d df04a09 6ce351e df04a09 63c5ebf df04a09 507a14d 8e499f4 df04a09 e5d5995 df04a09 e5d5995 ab74236 8e499f4 df04a09 e5d5995 4a1518a df04a09 8799e00 df04a09 8799e00 31bff5a 8799e00 9f4ce43 4a1518a df04a09 63c5ebf df04a09 9f4ce43 8799e00 874c0c9 df04a09 8799e00 bbe05a0 31bff5a 507a14d ca662db 874c0c9 ca662db 31bff5a 908984c 31bff5a 9ceb843 df04a09 31bff5a f89f357 df04a09 557b080 9ceb843 06fd8bd 31bff5a df04a09 06fd8bd 9ceb843 31bff5a df04a09 59b52cf 8799e00 91cb993 9ceb843 8e499f4 92c7f09 df04a09 8e499f4 e5d5995 8799e00 df04a09 91cb993 31bff5a bbe05a0 149a173 17f167a bd17252 149a173 bbe05a0 e5d5995 9ceb843 e5d5995 c8a4819 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
import os
import re

import gradio as gr
import numpy as np
from apscheduler.schedulers.background import BackgroundScheduler
from datasets import load_dataset
from huggingface_hub import HfApi, snapshot_download

from src.css import custom_css
from src.md import ABOUT_TEXT, TOP_TEXT
from src.utils import load_all_data, prep_df, sort_by_category
# Hub client; restart_space() uses it to restart this Space.
api = HfApi()
# Access token read from the environment (None when unset). It is passed to the
# Hub / datasets calls in this file.
COLLAB_TOKEN = os.getenv("COLLAB_TOKEN")
# Hub repos: the evaluation results (a dataset repo) and the evaluation prompts.
evals_repo = "alrope/href_results"
eval_set_repo = "allenai/href_validation"
# Local directory the results repo is mirrored into.
local_result_dir = "./results/"
def restart_space():
    """Ask the Hugging Face Hub to restart the allenai/href Space.

    Authenticates with the module-level COLLAB_TOKEN.
    """
    space_id = "allenai/href"
    api.restart_space(repo_id=space_id, token=COLLAB_TOKEN)
print("Pulling evaluation results")
# Mirror the results dataset repo into local_result_dir; load_all_data reads it below.
repo = snapshot_download(
    local_dir=local_result_dir,
    ignore_patterns=[],
    repo_id=evals_repo,
    token=COLLAB_TOKEN,  # `token` replaces the deprecated `use_auth_token` argument
    tqdm_class=None,
    etag_timeout=30,
    repo_type="dataset",
)
# Leaderboard tables for greedy (temperature=0.0) and sampled (temperature=1.0)
# decoding. Only the greedy table is displayed; the non-greedy tab is commented out.
href_data_greedy = prep_df(load_all_data(local_result_dir, subdir="temperature=0.0"))
href_data_nongreedy = prep_df(load_all_data(local_result_dir, subdir="temperature=1.0"))
# Gradio column datatypes: a leading "number" column, a "markdown" column, then numbers.
# NOTE(review): for n columns these lists have 2 + (n-1)//2 and n+1 entries, not n.
# Confirm against the columns prep_df actually produces.
col_types_href = ["number", "markdown"] + ["number"] * ((len(href_data_greedy.columns) - 1) // 2)
col_types_href_hidden = ["number", "markdown"] + ["number"] * (len(href_data_greedy.columns) - 1)
# Columns offered in the "Sorted By" dropdown.
categories = [
    'Average',
    'Brainstorm',
    'Open QA',
    'Closed QA',
    'Extract',
    'Generation',
    'Rewrite',
    'Summarize',
    'Classify',
    "Reasoning Over Numerical Data",
    "Multi-Document Synthesis",
    "Fact Checking or Attributed QA",
]
# Earlier, shorter category list:
# categories = ['Average', 'Brainstorm', 'Open QA', 'Closed QA', 'Extract', 'Generation', 'Rewrite', 'Summarize', 'Classify']
# Evaluation prompts, used by the Dataset Viewer tab to show random samples.
# `token` replaces the deprecated `use_auth_token` argument, which newer `datasets` releases dropped.
eval_set = load_dataset(eval_set_repo, token=COLLAB_TOKEN, split="dev")
def random_sample(r: gr.Request, category):
    """Return one randomly chosen evaluation example rendered as Markdown.

    Args:
        r: The Gradio request. Gradio presumably injects it because of the
            ``gr.Request`` annotation; this function does not use it.
        category: None/empty (sample from the whole eval set), a single
            category string, or a list of category strings. When given, only
            examples whose "category" field is one of them are sampled.

    Returns:
        A Markdown string with one "**key**:" heading per field of the sampled
        example, each followed by its value. If no example matches, a short
        notice is returned instead.
    """
    if not category:
        pool = eval_set
    else:
        if isinstance(category, str):
            category = [category]
        # Filter the dataset down to only the requested category(s).
        pool = eval_set.filter(lambda x: x["category"] in category)
    if len(pool) == 0:
        return "No samples found for the selected category."
    # numpy's randint excludes `high`, so pass len(pool) to make every row,
    # including the last, reachable. This also works for a one-row pool.
    sample_index = int(np.random.randint(0, len(pool)))
    sample = pool[sample_index]
    return '\n\n'.join(f"**{key}**:\n\n{value}" for key, value in sample.items())
# Distinct values of the eval set's "category" column; these are the Dataset Viewer filter choices.
subsets = eval_set.unique(column="category")
def regex_table(dataframe, regex, selected_category, style=True):
    """Filter the leaderboard to models whose name matches any search term.

    Args:
        dataframe: Leaderboard table (pandas DataFrame) with a "Model" column.
        regex: Comma-separated search terms. Each term is a case-insensitive
            regular expression matched anywhere in the "Model" value, and a row
            is kept if it matches any term.
            - Blank terms are ignored.
            - An empty, blank or None search keeps every row with a non-missing Model.
            - If the terms do not form a valid regular expression, each is
              matched as a literal substring instead.
        selected_category: Category forwarded to ``sort_by_category`` to order
            the rows.
        style: If True, return a pandas ``Styler`` with number formatting and
            right alignment; otherwise return the filtered DataFrame.

    Returns:
        The filtered (and optionally styled) table, re-indexed from 0.
    """
    dataframe = sort_by_category(dataframe, selected_category)
    # Split on commas, trim whitespace and drop blank terms. Otherwise a trailing
    # comma (e.g. "llama, ") adds an empty alternative that matches every row.
    terms = [term.strip() for term in (regex or "").split(",") if term.strip()]
    models = dataframe["Model"]
    try:
        # Join the terms into one pattern with '|' acting as OR.
        mask = models.str.contains("|".join(terms), case=False, na=False)
    except (re.error, ValueError):
        # The search box is free-form user input. An unbalanced "(" or "[" must
        # not crash the callback, so fall back to matching each term literally.
        # ValueError covers regex errors raised by non-`re` string backends.
        literal = "|".join(re.escape(term) for term in terms)
        mask = models.str.contains(literal, case=False, na=False)
    data = dataframe[mask].reset_index(drop=True)
    if style:
        # One decimal for score columns and two for 'Average'. The 'Model',
        # 'Rank' and '95% CI' columns keep their default rendering.
        format_dict = {col: "{:.1f}" for col in data.columns if col not in ['Average', 'Model', 'Rank', '95% CI']}
        format_dict['Average'] = "{:.2f}"
        data = data.style.format(format_dict, na_rep='').set_properties(**{'text-align': 'right'})
    return data
# Model count shown in the header: rows of the greedy table matched by an empty search.
total_models = len(regex_table(href_data_greedy.copy(), "", "Average", style=False))
with gr.Blocks(css=custom_css) as app:
    # Header row: intro text (with the model count) on the left, logo on the right.
    with gr.Row():
        with gr.Column(scale=8):
            gr.Markdown(TOP_TEXT.format(str(total_models)))
        with gr.Column(scale=2):
            # search = gr.Textbox(label="Model Search (delimit with , )", placeholder="Regex search for a model")
            # filter_button = gr.Checkbox(label="Include AI2 training runs (or type ai2 above).", interactive=True)
            # img = gr.Image(value="https://private-user-images.githubusercontent.com/10695622/310698241-24ed272a-0844-451f-b414-fde57478703e.png", width=500)
            gr.Markdown("""
<img src="file/src/logo.png" height="130">
""")
    # Tabs: leaderboard, About text, and a random-sample dataset viewer.
    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        with gr.TabItem("🏆 HREF Leaderboard"):
            with gr.Row():
                search_1 = gr.Textbox(label="Model Search (delimit with , )",
                                      # placeholder="Model Search (delimit with , )",
                                      show_label=True)
                category_selector_1 = gr.Dropdown(categories, label="Sorted By", value="Average", multiselect=False, show_label=True, elem_id="category_selector", elem_classes="category_selector_class")
            with gr.Row():
                # reference data: hidden full copy that the search/sort callbacks filter from
                rewardbench_table_hidden = gr.Dataframe(
                    href_data_greedy.values,
                    datatype=col_types_href_hidden,
                    headers=href_data_greedy.columns.tolist(),
                    visible=False,
                )
                rewardbench_table = gr.Dataframe(
                    regex_table(href_data_greedy.copy(), "", "Average"),
                    datatype=col_types_href,
                    headers=href_data_greedy.columns.tolist(),
                    elem_id="href_data_greedy",
                    interactive=False,
                    height=1000,
                )
        # with gr.TabItem("Non-Greedy"):
        #     with gr.Row():
        #         search_2 = gr.Textbox(label="Model Search (delimit with , )",
        #                               # placeholder="Model Search (delimit with , )",
        #                               show_label=True)
        #         category_selector_2 = gr.Dropdown(categories, label="Sorted By", value="Average",
        #                                           multiselect=False, show_label=True, elem_id="category_selector")
        #     with gr.Row():
        #         # reference data
        #         rewardbench_table_hidden_nongreedy = gr.Dataframe(
        #             href_data_nongreedy.values,
        #             datatype=col_types_href_hidden,
        #             headers=href_data_nongreedy.columns.tolist(),
        #             visible=False,
        #         )
        #         rewardbench_table_nongreedy = gr.Dataframe(
        #             regex_table(href_data_nongreedy.copy(), "", "Average"),
        #             datatype=col_types_href,
        #             headers=href_data_nongreedy.columns.tolist(),
        #             elem_id="href_data_nongreedy",
        #             interactive=False,
        #             height=1000,
        #         )
        with gr.TabItem("About"):
            with gr.Row():
                gr.Markdown(ABOUT_TEXT)
        with gr.TabItem("Dataset Viewer"):
            with gr.Row():
                # loads one sample
                gr.Markdown("""## Random Dataset Sample Viewer""")
                subset_selector = gr.Dropdown(subsets, label="Category", value=None, multiselect=True)
                button = gr.Button("Show Random Sample")
            with gr.Row():
                sample_display = gr.Markdown("{sampled data loads here}")
            button.click(fn=random_sample, inputs=[subset_selector], outputs=[sample_display])
    # Rebuild the visible table from the hidden copy whenever the search text or sort column changes.
    search_1.change(regex_table, inputs=[rewardbench_table_hidden, search_1, category_selector_1], outputs=rewardbench_table)
    category_selector_1.change(regex_table, inputs=[rewardbench_table_hidden, search_1, category_selector_1], outputs=rewardbench_table)
    # search_2.change(regex_table, inputs=[rewardbench_table_hidden_nongreedy, search_2, category_selector_2], outputs=rewardbench_table_nongreedy)
    # category_selector_2.change(regex_table, inputs=[rewardbench_table_hidden_nongreedy, search_2, category_selector_2], outputs=rewardbench_table_nongreedy)
    with gr.Row():
        with gr.Accordion("📚 Citation", open=False):
            # NOTE(review): this entry cites RewardBench rather than HREF; confirm the intended citation.
            # The closing brace of `howpublished` was missing, which left the @misc entry unterminated.
            citation_button = gr.Textbox(
                value=r"""@misc{RewardBench,
title={RewardBench: Evaluating Reward Models for Language Modeling},
author={Lambert, Nathan and Pyatkin, Valentina and Morrison, Jacob and Miranda, LJ and Lin, Bill Yuchen and Chandu, Khyathi and Dziri, Nouha and Kumar, Sachin and Zick, Tom and Choi, Yejin and Smith, Noah A. and Hajishirzi, Hannaneh},
year={2024},
howpublished={\url{https://huggingface.co/spaces/allenai/reward-bench}}
}""",
                lines=7,
                label="Copy the following to cite these results.",
                elem_id="citation-button",
                show_copy_button=True,
            )
# Restart the Space every 3 hours. Each restart re-runs this module, which
# pulls fresh evaluation results at startup.
scheduler = BackgroundScheduler()
scheduler.add_job(restart_space, trigger="interval", hours=3)
scheduler.start()
# allowed_paths lets Gradio serve files under src/, e.g. the header logo.
# NOTE: an earlier version called .queue() before launch(); unclear whether it is needed.
app.launch(allowed_paths=['src/'])
|