|
import os
import re

import gradio as gr
import numpy as np
from apscheduler.schedulers.background import BackgroundScheduler
from datasets import load_dataset
from huggingface_hub import HfApi, snapshot_download

from src.css import custom_css
from src.md import ABOUT_TEXT, TOP_TEXT
from src.utils import load_all_data, prep_df, sort_by_category
|
|
|
# Hub API client; used below to restart this Space.
api = HfApi()

# Token read from the environment (None when the variable is unset). It is
# passed to every Hub call in this file: downloads and the Space restart.
COLLAB_TOKEN = os.getenv("COLLAB_TOKEN")

# Hub dataset repos: per-model evaluation results, and the evaluation
# prompts browsed in the "Dataset Viewer" tab.
evals_repo = "alrope/href_results"
eval_set_repo = "alrope/test_dev"

# Local directory the results repo is mirrored into.
local_result_dir = "./results/"
|
|
|
def restart_space():
    """Restart the "alrope/href" Space via the Hub API.

    Scheduled to run on a fixed interval at the bottom of this file; on
    restart the module re-downloads the results repo.
    """
    api.restart_space(token=COLLAB_TOKEN, repo_id="alrope/href")
|
|
|
# Mirror the evaluation-results dataset repo into local_result_dir before
# building the leaderboard tables.
print("Pulling evaluation results")
repo = snapshot_download(
    local_dir=local_result_dir,
    ignore_patterns=[],
    repo_id=evals_repo,
    # `token=` is the current name of the deprecated `use_auth_token=` argument.
    token=COLLAB_TOKEN,
    tqdm_class=None,
    etag_timeout=30,
    repo_type="dataset",
)
|
|
|
# One leaderboard table per decoding setting, each built from its own
# results subdirectory.
href_data_greedy = prep_df(load_all_data(local_result_dir, subdir="temperature=0.0"))
href_data_nongreedy = prep_df(load_all_data(local_result_dir, subdir="temperature=1.0"))

# Gradio `datatype` lists for the visible and hidden tables; both are derived
# from the greedy table's columns and reused for the non-greedy tab.
# NOTE(review): the hidden list has len(columns) + 1 entries and the visible
# one roughly half the score columns — verify against prep_df's output schema.
col_types_href = ["number", "markdown"] + ["number"] * ((href_data_greedy.shape[1] - 1) // 2)
col_types_href_hidden = ["number", "markdown"] + ["number"] * (href_data_greedy.shape[1] - 1)

# Task categories offered in the "Sorted By" dropdowns.
categories = [
    'Brainstorm',
    'Open QA',
    'Closed QA',
    'Extract',
    'Generation',
    'Rewrite',
    'Summarize',
    'Classify',
    "Reasoning Over Numerical Data",
    "Multi-Document Synthesis",
    "Fact Checking or Attributed QA",
]
|
|
|
|
|
|
|
# Evaluation prompts ("dev" split) browsed in the "Dataset Viewer" tab.
# `token=` replaces `use_auth_token=`, which recent `datasets` releases removed.
eval_set = load_dataset(eval_set_repo, token=COLLAB_TOKEN, split="dev")
|
def random_sample(r: gr.Request, category):
    """Pick one random evaluation example and render it as Markdown.

    Args:
        r: Request object injected by Gradio (not used).
        category: None or an empty value samples from the whole eval set; a
            string or list of strings restricts sampling to examples whose
            "category" field is one of them.

    Returns:
        Markdown text with one "**key**:" section per field of the sampled
        example, or a short notice when no example matches the filter.
    """
    if not category:
        pool = eval_set
    else:
        wanted = [category] if isinstance(category, str) else list(category)
        pool = eval_set.filter(lambda x: x["category"] in wanted)

    if len(pool) == 0:
        return "No examples found for the selected category."

    # randint's upper bound is exclusive: len(pool) (not len(pool) - 1) lets
    # the last example be drawn and keeps single-example pools from raising.
    sample = pool[int(np.random.randint(0, len(pool)))]
    return "\n\n".join(f"**{key}**:\n\n{value}" for key, value in sample.items())
|
|
|
# Distinct category labels in the eval set; offered as filters in the
# "Dataset Viewer" tab.
subsets = eval_set.unique(column="category")
|
|
|
|
|
def regex_table(dataframe, regex, selected_category, style=True):
    """Sort the leaderboard, then keep only rows whose model name matches a search.

    Args:
        dataframe: Leaderboard table containing a "Model" column.
        regex: Comma-separated search terms. Each term is matched
            case-insensitively anywhere in the model name, as a regular
            expression. Blank terms (e.g. from a trailing comma) are ignored.
            If the combined expression is not a valid regex, the terms are
            matched literally instead. With no terms left, every row with a
            non-null model name is kept.
        selected_category: Column name handed to ``sort_by_category``.
        style: When True, return a pandas Styler with number formatting
            applied; otherwise return the filtered DataFrame.

    Returns:
        The filtered, re-indexed table (a Styler when ``style`` is True).
    """
    dataframe = sort_by_category(dataframe, selected_category)

    # Drop empty terms: an empty alternative in "a|" would match every model.
    terms = [term.strip() for term in (regex or "").split(",")]
    terms = [term for term in terms if term]
    combined_regex = "|".join(terms)
    try:
        re.compile(combined_regex)
    except re.error:
        # Free-text search box: fall back to literal matching rather than
        # failing the whole callback on an invalid pattern such as "(".
        combined_regex = "|".join(re.escape(term) for term in terms)

    mask = dataframe["Model"].str.contains(combined_regex, case=False, na=False)
    # Reassign instead of resetting in place on a filtered slice.
    data = dataframe[mask].reset_index(drop=True)

    if style:
        format_dict = {col: "{:.1f}" for col in data.columns if col not in ['Average', 'Model', 'Rank']}
        format_dict['Average'] = "{:.2f}"
        data = data.style.format(format_dict, na_rep='').set_properties(**{'text-align': 'right'})
    return data
|
|
|
|
|
# Number of rows in the unfiltered greedy leaderboard, shown in the header text.
total_models = len(regex_table(href_data_greedy.copy(), "", "Average", style=False))
|
|
|
# "Average" is the initial sort key of both leaderboards but is not one of
# `categories`; list it first so users can switch back to it after sorting
# by a category.
sort_choices = ["Average"] + [c for c in categories if c != "Average"]

with gr.Blocks(css=custom_css) as app:
    # Header: intro text (with the model count filled in) next to the logo.
    with gr.Row():
        with gr.Column(scale=6):
            gr.Markdown(TOP_TEXT.format(str(total_models)))
        with gr.Column(scale=4):
            gr.Markdown("""
![](file/src/logo.png)
""")

    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        with gr.TabItem("🏆 HREF Leaderboard"):
            with gr.Row():
                search_1 = gr.Textbox(label="Model Search (delimit with , )",
                                      show_label=True)
                category_selector_1 = gr.Dropdown(sort_choices, label="Sorted By", value="Average", multiselect=False, show_label=True)
            with gr.Row():
                # Hidden, unstyled copy of the full table; the search/sort
                # callbacks read it and rebuild the visible table from it.
                rewardbench_table_hidden = gr.Dataframe(
                    href_data_greedy.values,
                    datatype=col_types_href_hidden,
                    headers=href_data_greedy.columns.tolist(),
                    visible=False,
                )
                rewardbench_table = gr.Dataframe(
                    regex_table(href_data_greedy.copy(), "", "Average"),
                    datatype=col_types_href,
                    headers=href_data_greedy.columns.tolist(),
                    elem_id="href_data_greedy",
                    interactive=False,
                    max_height=1000,
                )

        with gr.TabItem("Non-Greedy"):
            with gr.Row():
                search_2 = gr.Textbox(label="Model Search (delimit with , )",
                                      show_label=True)
                category_selector_2 = gr.Dropdown(sort_choices, label="Sorted By", value="Average", multiselect=False, show_label=True)
            with gr.Row():
                # Hidden full copy backing the visible non-greedy table.
                rewardbench_table_hidden_nongreedy = gr.Dataframe(
                    href_data_nongreedy.values,
                    datatype=col_types_href_hidden,
                    headers=href_data_nongreedy.columns.tolist(),
                    visible=False,
                )
                rewardbench_table_nongreedy = gr.Dataframe(
                    regex_table(href_data_nongreedy.copy(), "", "Average"),
                    datatype=col_types_href,
                    headers=href_data_nongreedy.columns.tolist(),
                    elem_id="href_data_nongreedy",
                    interactive=False,
                    max_height=1000,
                )

        with gr.TabItem("About"):
            with gr.Row():
                gr.Markdown(ABOUT_TEXT)

        with gr.TabItem("Dataset Viewer"):
            with gr.Row():
                # NOTE(review): this warning names RewardBench subsets (XSTest,
                # donotanswer); confirm it applies to this eval set.
                gr.Markdown("""## Random Dataset Sample Viewer

Warning, refusals, XSTest, and donotanswer datasets have sensitive content.""")
                subset_selector = gr.Dropdown(subsets, label="Category", value=None, multiselect=True)
                button = gr.Button("Show Random Sample")

            with gr.Row():
                sample_display = gr.Markdown("{sampled data loads here}")

            button.click(fn=random_sample, inputs=[subset_selector], outputs=[sample_display])

    # Rebuild each visible table from its hidden full copy whenever the search
    # text or the sort column changes.
    search_1.change(regex_table, inputs=[rewardbench_table_hidden, search_1, category_selector_1], outputs=rewardbench_table)
    category_selector_1.change(regex_table, inputs=[rewardbench_table_hidden, search_1, category_selector_1], outputs=rewardbench_table)
    search_2.change(regex_table, inputs=[rewardbench_table_hidden_nongreedy, search_2, category_selector_2], outputs=rewardbench_table_nongreedy)
    category_selector_2.change(regex_table, inputs=[rewardbench_table_hidden_nongreedy, search_2, category_selector_2], outputs=rewardbench_table_nongreedy)

    with gr.Row():
        with gr.Accordion("📚 Citation", open=False):
            # NOTE(review): this is the RewardBench citation, apparently carried
            # over from that leaderboard's code — confirm it is intended here.
            # The closing brace of `howpublished` was missing (invalid BibTeX).
            citation_button = gr.Textbox(
                value=r"""@misc{RewardBench,
    title={RewardBench: Evaluating Reward Models for Language Modeling},
    author={Lambert, Nathan and Pyatkin, Valentina and Morrison, Jacob and Miranda, LJ and Lin, Bill Yuchen and Chandu, Khyathi and Dziri, Nouha and Kumar, Sachin and Zick, Tom and Choi, Yejin and Smith, Noah A. and Hajishirzi, Hannaneh},
    year={2024},
    howpublished={\url{https://huggingface.co/spaces/allenai/reward-bench}}
}""",
                lines=7,
                label="Copy the following to cite these results.",
                elem_id="citation-button",
                show_copy_button=True,
            )
|
|
|
|
|
# Restart the Space every 3 hours; on restart the module re-downloads the
# results repo, so newly pushed results appear on the leaderboard.
scheduler = BackgroundScheduler()
scheduler.add_job(restart_space, trigger="interval", hours=3)
scheduler.start()

# Allow Gradio to serve files under src/ (the header Markdown links src/logo.png).
app.launch(allowed_paths=["src/"])
|
|