File size: 5,290 Bytes
631e505
 
791e69b
631e505
 
 
 
 
 
 
 
 
 
 
 
 
 
791e69b
631e505
 
791e69b
631e505
 
 
 
 
ebb58d8
 
 
 
631e505
791e69b
 
631e505
 
b2821de
631e505
 
 
 
 
 
 
 
 
 
 
 
 
4001fbf
 
631e505
 
 
 
 
 
 
 
 
 
 
 
ebb58d8
 
 
 
 
 
 
 
 
 
631e505
 
ebb58d8
631e505
 
4001fbf
631e505
 
 
 
 
 
4001fbf
631e505
 
 
 
 
 
 
 
 
 
ebb58d8
631e505
 
4001fbf
631e505
 
 
4001fbf
631e505
4001fbf
631e505
4001fbf
 
 
 
 
 
631e505
 
 
4001fbf
631e505
 
 
 
 
 
 
 
 
4001fbf
 
 
 
 
 
 
 
 
 
631e505
ebb58d8
4001fbf
631e505
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import json
import re
from os.path import split as path_split, splitext as path_splitext

import pandas as pd
import streamlit as st

# Page-level configuration for the app.
# NOTE(review): Streamlit's documentation asks for set_page_config to run
# before other Streamlit commands - confirm for the installed version before
# moving any st.* call above this one.
st.set_page_config(
    page_title="PPE Metrics Explorer",
    layout="wide",  # This makes the app use the entire screen width
    initial_sidebar_state="expanded",
)

# Render the app title at the top of the page
st.title("PPE Metrics Explorer")

@st.cache_data
def load_data(file_path):
    """Read a JSON file and return the decoded document.

    The function is wrapped in ``st.cache_data``.
    NOTE(review): presumably this means the file is parsed once per distinct
    ``file_path`` and later edits on disk are not picked up until the cache
    is cleared - confirm against the Streamlit caching documentation.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        Whatever ``json.load`` produces for the file's contents.

    Raises:
        OSError: If the file cannot be opened (e.g. ``FileNotFoundError``).
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding (the default when ``encoding`` is omitted).
    with open(file_path, encoding="utf-8") as file:
        return json.load(file)

def contains_list(column):
    """Tell whether at least one element of ``column`` is a ``list``.

    ``column`` must provide ``.apply`` (element-wise) whose result provides
    ``.any()``; in this file it is called once per DataFrame column via
    ``DataFrame.apply``.
    """
    def _is_list(cell):
        return isinstance(cell, list)

    list_mask = column.apply(_is_list)
    return list_mask.any()

# Metric names that main() reports as ``1 - value`` under a
# "(1 - <metric>)" column label (presumably so that higher reads as better).
INVERT = {
    'brier',
    'loss',
}

# Metric names whose values main() multiplies by 100 before display.
SCALE = {
    'accuracy',
    'row-wise pearson',
    'confidence_agreement',
    'spearman',
    'kendalltau',
    'arena_under_curve',
    'mean_max_score',
    'mean_end_score',
}

def _flatten_metrics(metrics):
    """Flatten one model's metrics dict into a single-level dict.

    Nested dicts become ``"<subkey> - <metric>"`` entries. Numeric values
    whose metric name is in ``SCALE`` are multiplied by 100; numeric values
    whose metric name is in ``INVERT`` are stored as ``1 - value`` under
    ``"<subkey> - (1 - <metric>)"``. Non-numeric values (lists, ``None``,
    strings) are stored unchanged, because arithmetic on them would either
    raise or silently repeat the sequence.
    """
    flattened = {}
    for subkey, submetrics in metrics.items():
        if not isinstance(submetrics, dict):
            flattened[subkey] = submetrics
            continue
        for metric_name, value in submetrics.items():
            key = f"{subkey} - {metric_name}"
            if isinstance(value, (int, float)):
                if metric_name in SCALE:
                    value = 100 * value
                if metric_name in INVERT:
                    key = f"{subkey} - (1 - {metric_name})"
                    value = 1 - value
            flattened[key] = value
    return flattened


def _build_records(benchmark_data):
    """Turn ``{model_path: metrics}`` into a list of flat row dicts."""
    records = []
    for model, metrics in benchmark_data.items():
        # Entries whose key ends in ".jsonl" are labelled as LLM judges.
        model_type = "LLM Judge" if model.endswith(".jsonl") else "Reward Model"
        # Display name: last path component without its extension.
        model_name = path_split(path_splitext(model)[0])[-1]
        if isinstance(metrics, dict):
            row = _flatten_metrics(metrics)
        else:
            # Metrics are not nested: expose the raw value directly.
            row = {"Value": metrics}
        records.append({"Model": model_name, "Type": model_type, **row})
    return records


def _select_columns(df, column_search):
    """Return ``df`` limited to "Type" plus the columns matching the search.

    Matching is a case-insensitive substring test. Returns ``df`` unchanged
    when the search is empty and ``None`` when no column matches.
    """
    if not column_search:
        return df
    needle = column_search.lower()
    matches = [col for col in df.columns if needle in col.lower()]
    if not matches:
        return None
    # "Type" is always shown first; skip it in the matches so it is never
    # selected twice (duplicate labels break sort_values).
    return df[["Type"] + [col for col in matches if col != "Type"]]


def _model_filter_pattern(model_search):
    """Build the regex used to filter model names, or ``None`` for no filter.

    The input is split on commas; surrounding whitespace and empty criteria
    are dropped so a trailing comma does not match every model. Criteria are
    joined with ``|``. If the result is not a valid regular expression, every
    criterion is matched literally instead.
    """
    terms = [term.strip() for term in model_search.split(",")]
    terms = [term for term in terms if term]
    if not terms:
        return None
    pattern = "|".join(terms)
    try:
        re.compile(pattern)
    except re.error:
        pattern = "|".join(re.escape(term) for term in terms)
    return pattern


def main():
    """Render the metrics explorer page.

    Loads ``results.json``, lets the user pick a benchmark, builds one table
    row per model, applies the optional column and model filters, and shows
    the table together with a CSV download button.
    """
    data = load_data('results.json')

    # Prefixing "A" makes "human_preference_v1" sort ahead of lowercase names.
    benchmarks = sorted(data, key=lambda s: "A" + s if s == "human_preference_v1" else s)
    selected_benchmark = st.selectbox("Select a Benchmark", benchmarks)
    if selected_benchmark is None:
        # Nothing to select, i.e. the results file has no benchmarks.
        st.warning("No benchmarks available.")
        return

    records = _build_records(data[selected_benchmark])
    if not records:
        st.warning("No models found for this benchmark.")
        return

    df = pd.DataFrame(records)

    # Drop columns that contain lists
    df = df.loc[:, ~df.apply(contains_list)]

    if "human" not in selected_benchmark:
        # Case-insensitive column order with "Type" first; "(1 - x)" labels
        # are keyed as if they started with "l".
        df = df[sorted(df.columns, key=lambda s: s.replace("(1", "l").lower() if s != "Type" else "A")]

    df = df.set_index("Model")

    # Two search boxes followed by an unused spacer column.
    col1, col2, _ = st.columns([1, 1, 2])
    with col1:
        # Non-empty label hidden via label_visibility instead of an empty
        # label, which Streamlit discourages for accessibility reasons.
        column_search = st.text_input(
            "Search metrics",
            placeholder="Search metrics...",
            key="search",
            label_visibility="hidden",
        )
    with col2:
        model_search = st.text_input(
            "Filter models",
            placeholder="Filter Models (separate criteria with ,) ...",
            key="search2",
            label_visibility="hidden",
        )

    df_display = _select_columns(df, column_search)
    if df_display is None:
        st.warning("No columns match your search.")
        df_display = pd.DataFrame()  # Empty DataFrame

    pattern = _model_filter_pattern(model_search)
    # An empty placeholder frame has no string index to filter on.
    if pattern is not None and not df_display.empty:
        df_display = df_display[df_display.index.str.contains(pattern, case=False)]
        if df_display.empty:
            st.warning("No models match your filter.")
            df_display = pd.DataFrame()  # Empty DataFrame

    if df_display.empty:
        table = df_display
    else:
        columns = df_display.columns
        # Sort by the first column after "Type"; fall back to the only
        # column when the search left nothing else.
        sort_column = columns[1] if len(columns) > 1 else columns[0]
        table = (
            df_display.sort_values(sort_column, ascending=False)
            .style.background_gradient(cmap='summer_r', axis=0)
            .format(precision=4)
        )
    st.dataframe(table, use_container_width=True, height=500)

    # Optional: Allow user to download the data as CSV
    st.download_button(
        label="Download data as CSV",
        data=df_display.to_csv(),
        file_name=f"{selected_benchmark}_metrics.csv",
        mime='text/csv',
    )

# Build the page only when this file is executed as the main script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()