"""
Given a directory of results, plot the benchmarks for each task as a bar chart and line chart.
"""

import argparse
import os
from typing import Optional

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from dgeb import TaskResult, get_all_tasks, get_output_folder, get_tasks_by_name

ALL_TASKS = [task.metadata.id for task in get_all_tasks()]


def plot_benchmarks(
    results_dir: str,
    task_ids: Optional[list[str]] = None,
    output: str = "benchmarks.png",
    model_substring: Optional[list[str]] = None,
):
    # each entry under results_dir is one model; sort for deterministic iteration
    models = sorted(os.listdir(results_dir))
    all_results = []
    tasks = get_all_tasks() if task_ids is None else get_tasks_by_name(task_ids)
    for model_name in models:
        if model_substring is not None and all(
            substr not in model_name for substr in model_substring
        ):
            continue

        for task in tasks:
            if task.metadata.display_name == "NoOp Task":
                continue
            filepath = get_output_folder(model_name, task, results_dir, create=False)
            # if the file does not exist, skip
            if not os.path.exists(filepath):
                continue

            with open(filepath) as f:
                task_result = TaskResult.model_validate_json(f.read())
            num_params = task_result.model["num_params"]
            primary_metric_id = task_result.task.primary_metric_id
            main_scores = [
                metric.value
                for layer_result in task_result.results
                for metric in layer_result.metrics
                if metric.id == primary_metric_id
            ]
            # skip results that do not report the primary metric,
            # which would otherwise make max() raise on an empty list
            if not main_scores:
                continue
            best_score = max(main_scores)
            all_results.append(
                {
                    "task": task.metadata.display_name,
                    "model": model_name,
                    "num_params": num_params,
                    "score": best_score,
                }
            )

    results_df = pd.DataFrame(all_results)
    # order the models by ascending number of parameters
    results_df["num_params"] = results_df["num_params"].astype(int)
    results_df = results_df.sort_values(by="num_params")
    # distinct tasks, in a stable (first-seen) order rather than arbitrary set order
    task_names = list(dict.fromkeys(results_df["task"]))
    n_tasks = len(task_names)

    # squeeze=False keeps ax a 2-D array even when there is a single task,
    # so the same indexing works for every n_tasks
    _, ax = plt.subplots(2, n_tasks, figsize=(5 * n_tasks, 10), squeeze=False)

    for i, task in enumerate(task_names):
        task_df = results_df[results_df["task"] == task]

        # first row: bar chart of the best score per model
        sns.barplot(x="model", y="score", data=task_df, ax=ax[0][i])
        ax[0][i].set_title(task)
        # rotate the x axis labels so long model names remain readable
        for tick in ax[0][i].get_xticklabels():
            tick.set_rotation(90)

        # second row: line graph with number of parameters on the x axis
        sns.lineplot(x="num_params", y="score", data=task_df, ax=ax[1][i])
        ax[1][i].set_title(task)
        ax[1][i].set_xlabel("Number of parameters")

    plt.tight_layout()
    plt.savefig(output)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-d",
        "--results_dir",
        type=str,
        default="results",
        help="Directory containing the results of the benchmarking",
    )
    parser.add_argument(
        "-t",
        "--tasks",
        type=lambda s: s.split(","),
        default=None,
        help=f"Comma-separated list of task ids to plot. Choose from {ALL_TASKS}, or omit to plot all tasks.",
    )
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        default="benchmarks.png",
        help="Output file for the plot",
    )
    parser.add_argument(
        "--model_substring",
        type=lambda s: s.split(","),
        default=None,
        help="Comma-separated list of model-name substrings. Only plot results for models whose name contains at least one of these substrings",
    )
    args = parser.parse_args()

    plot_benchmarks(args.results_dir, args.tasks, args.output, args.model_substring)
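
    # Example invocations (a sketch — the task ids and substring below are
    # placeholders; pick real ids from ALL_TASKS and point -d at a directory
    # of results written by dgeb):
    #
    #   python plot_benchmarks.py -d results -o benchmarks.png
    #   python plot_benchmarks.py -t <task_id_1>,<task_id_2> --model_substring <model_substr>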