saattrupdan committed on
Commit
1ef58ee
·
1 Parent(s): 1c2b5d0

feat: Initial commit

Browse files
Files changed (3) hide show
  1. .gitignore +1 -0
  2. app.py +370 -0
  3. requirements.txt +69 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .venv
app.py ADDED
@@ -0,0 +1,370 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Script to produce radial plots."""
2
+
3
+ from functools import partial
4
+ import plotly.graph_objects as go
5
+ import json
6
+ import numpy as np
7
+ from collections import defaultdict
8
+ import pandas as pd
9
+ from pydantic import BaseModel
10
+ import gradio as gr
11
+ import requests
12
+
13
+
14
class Task(BaseModel):
    """A benchmark task together with the metric used to score it.

    Attributes:
        name:
            The human-readable name of the task.
        metric:
            The key of the metric to read from the benchmark results.
    """

    name: str
    metric: str

    def __hash__(self) -> int:
        # Hash on the name only, so tasks can serve as dictionary keys and as
        # dataframe column labels.
        task_name = self.name
        return hash(task_name)
22
+
23
+
24
class Language(BaseModel):
    """A language identified by a short code and a display name.

    Attributes:
        code:
            The short code of the language, e.g. "da".
        name:
            The display name of the language, e.g. "Danish".
    """

    code: str
    name: str

    def __hash__(self) -> int:
        # Hash on the code only, so languages can be placed in sets and used
        # as dictionary keys.
        language_code = self.code
        return hash(language_code)
32
+
33
+
34
class Dataset(BaseModel):
    """A benchmark dataset, tied to one language and one task.

    Attributes:
        name:
            The name of the dataset, as it appears in the benchmark records.
        language:
            The language of the dataset.
        task:
            The task that the dataset evaluates.
    """

    name: str
    language: Language
    task: Task

    def __hash__(self) -> int:
        # Hash on the dataset name only.
        dataset_name = self.name
        return hash(dataset_name)
43
+
44
+
45
# The tasks shown on the radial plot, each with the metric key that is looked up
# in the benchmark records.
TEXT_CLASSIFICATION = Task(name="text classification", metric="mcc")
INFORMATION_EXTRACTION = Task(name="information extraction", metric="micro_f1_no_misc")
GRAMMAR = Task(name="grammar", metric="mcc")
QUESTION_ANSWERING = Task(name="question answering", metric="em")
SUMMARISATION = Task(name="summarisation", metric="bertscore")
KNOWLEDGE = Task(name="knowledge", metric="mcc")
REASONING = Task(name="reasoning", metric="mcc")

# Explicit registry rather than a scan of `globals()`: a scan would also pick up
# any other `Task` bound at module scope (e.g. an alias), yielding duplicates.
# NOTE: new tasks must be added here to appear in the plot.
ALL_TASKS = [
    TEXT_CLASSIFICATION,
    INFORMATION_EXTRACTION,
    GRAMMAR,
    QUESTION_ANSWERING,
    SUMMARISATION,
    KNOWLEDGE,
    REASONING,
]
53
+
54
# The languages that can be selected in the demo.
DANISH = Language(code="da", name="Danish")
NORWEGIAN = Language(code="no", name="Norwegian")
SWEDISH = Language(code="sv", name="Swedish")
ICELANDIC = Language(code="is", name="Icelandic")
FAROESE = Language(code="fo", name="Faroese")
GERMAN = Language(code="de", name="German")
DUTCH = Language(code="nl", name="Dutch")
ENGLISH = Language(code="en", name="English")

# Mapping from display name to language. Built from an explicit tuple rather
# than a scan of `globals()`, so that unrelated module-level `Language` objects
# can never leak into it.
# NOTE: new languages must be added here to become selectable.
ALL_LANGUAGES = {
    language.name: language
    for language in (
        DANISH,
        NORWEGIAN,
        SWEDISH,
        ICELANDIC,
        FAROESE,
        GERMAN,
        DUTCH,
        ENGLISH,
    )
}
65
+
66
# All datasets whose benchmark records are used, listed task by task. Each entry
# pairs the dataset name found in the benchmark records with its language.
DATASETS = [
    Dataset(name=dataset_name, language=dataset_language, task=dataset_task)
    for dataset_task, names_and_languages in (
        (
            TEXT_CLASSIFICATION,
            (
                ("swerec", SWEDISH),
                ("angry-tweets", DANISH),
                ("norec", NORWEGIAN),
                ("sb10k", GERMAN),
                ("dutch-social", DUTCH),
                ("sst5", ENGLISH),
            ),
        ),
        (
            INFORMATION_EXTRACTION,
            (
                ("suc3", SWEDISH),
                ("dansk", DANISH),
                ("norne-nb", NORWEGIAN),
                ("norne-nn", NORWEGIAN),
                ("mim-gold-ner", ICELANDIC),
                ("fone", FAROESE),
                ("germeval", GERMAN),
                ("conll-nl", DUTCH),
                ("conll-en", ENGLISH),
            ),
        ),
        (
            GRAMMAR,
            (
                ("scala-sv", SWEDISH),
                ("scala-da", DANISH),
                ("scala-nb", NORWEGIAN),
                ("scala-nn", NORWEGIAN),
                ("scala-is", ICELANDIC),
                ("scala-fo", FAROESE),
                ("scala-de", GERMAN),
                ("scala-nl", DUTCH),
                ("scala-en", ENGLISH),
            ),
        ),
        (
            QUESTION_ANSWERING,
            (
                ("scandiqa-da", DANISH),
                ("norquad", NORWEGIAN),
                ("scandiqa-sv", SWEDISH),
                ("nqii", ICELANDIC),
                ("germanquad", GERMAN),
                ("squad", ENGLISH),
                ("squad-nl", DUTCH),
            ),
        ),
        (
            SUMMARISATION,
            (
                ("nordjylland-news", DANISH),
                ("mlsum", GERMAN),
                ("rrn", ICELANDIC),
                ("no-sammendrag", NORWEGIAN),
                ("wiki-lingua-nl", DUTCH),
                ("swedn", SWEDISH),
                ("cnn-dailymail", ENGLISH),
            ),
        ),
        (
            KNOWLEDGE,
            (
                ("mmlu-da", DANISH),
                ("mmlu-no", NORWEGIAN),
                ("mmlu-sv", SWEDISH),
                ("mmlu-is", ICELANDIC),
                ("mmlu-de", GERMAN),
                ("mmlu-nl", DUTCH),
                ("mmlu", ENGLISH),
                ("arc-da", DANISH),
                ("arc-no", NORWEGIAN),
                ("arc-sv", SWEDISH),
                ("arc-is", ICELANDIC),
                ("arc-de", GERMAN),
                ("arc-nl", DUTCH),
                ("arc", ENGLISH),
            ),
        ),
        (
            REASONING,
            (
                ("hellaswag-da", DANISH),
                ("hellaswag-no", NORWEGIAN),
                ("hellaswag-sv", SWEDISH),
                ("hellaswag-is", ICELANDIC),
                ("hellaswag-de", GERMAN),
                ("hellaswag-nl", DUTCH),
                ("hellaswag", ENGLISH),
            ),
        ),
    )
    for dataset_name, dataset_language in names_and_languages
]
127
+
128
+
129
def main() -> None:
    """Download the benchmark results and launch the radial plot demo.

    The benchmark records are fetched as JSON lines and turned into one results
    dataframe per language, with models as the index, tasks as the columns and
    mean scores as the values. A Gradio interface is then built around
    `produce_radial_plot` and launched.

    Raises:
        requests.RequestException:
            If the benchmark results cannot be downloaded.
    """
    # Download all the newest records. The timeout prevents the app from hanging
    # forever at start-up if the server does not respond.
    response = requests.get(
        "https://scandeval.com/scandeval_benchmark_results.jsonl", timeout=60
    )
    response.raise_for_status()
    records = [
        json.loads(line)
        for line in response.text.split("\n")
        # Skip lines with nothing but whitespace (e.g. a stray carriage
        # return), which are not valid JSON.
        if line.strip()
    ]

    # Collect all the scores as language -> model -> task -> list of scores.
    # Datasets are looked up by name once, rather than scanned per record.
    datasets_by_name = {dataset.name: dataset for dataset in DATASETS}
    scores_by_language = {
        dataset.language: defaultdict(lambda: defaultdict(list))
        for dataset in DATASETS
    }
    for record in records:
        dataset = datasets_by_name.get(record["dataset"])
        if dataset is None:
            continue
        totals = record["results"]["total"]
        metric = dataset.task.metric
        score = totals.get(f"test_{metric}", totals.get(metric))

        # A record without the task's metric cannot be averaged, so skip it
        if score is None:
            continue
        scores_by_language[dataset.language][record["model"]][dataset.task].append(
            score
        )

    # Build a dictionary of languages -> results-dataframes, whose indices are the
    # models and columns are the tasks.
    results_dfs = dict()
    for language, scores_by_model in scores_by_language.items():
        mean_scores = {
            model_id: {
                task: float(np.mean(task_scores))
                for task, task_scores in scores_by_task.items()
            }
            for model_id, scores_by_task in scores_by_model.items()
        }

        # Models that lack a score for any of the seen tasks are dropped
        results_df = pd.DataFrame(mean_scores).T.dropna()

        # A language only gets results if every task is covered
        if any(task not in results_df.columns for task in ALL_TASKS):
            results_dfs[language] = pd.DataFrame()
        else:
            results_dfs[language] = results_df

    all_languages: list[str | int | float | tuple[str, str | int | float]] | None = [
        language.name for language in ALL_LANGUAGES.values()
    ]

    # Sorted, so that the dropdown order is the same on every launch
    all_models: list[str | int | float | tuple[str, str | int | float]] | None = (
        sorted({model_id for df in results_dfs.values() for model_id in df.index}, key=str)
    )

    # Only preselect models that have results for the preselected languages, as
    # the initial plot cannot be drawn for the others.
    default_language_names = ["Danish"]
    default_model_ids = [
        model_id
        for model_id in ("gpt-3.5-turbo-0613", "mistralai/Mistral-7B-v0.1")
        if all(
            model_id in results_dfs[ALL_LANGUAGES[language_name]].index
            for language_name in default_language_names
        )
    ]

    with gr.Blocks() as demo:
        gr.Markdown("# Radial Plot Generator")
        gr.Markdown("### Select the models and languages to include in the plot")
        with gr.Row():
            with gr.Column():
                language_names_dropdown = gr.Dropdown(
                    choices=all_languages,
                    multiselect=True,
                    label="Languages",
                    value=default_language_names,
                    interactive=True,
                )
                model_ids_dropdown = gr.Dropdown(
                    choices=all_models,
                    multiselect=True,
                    label="Models",
                    value=default_model_ids,
                    interactive=True,
                )
                use_win_ratio_checkbox = gr.Checkbox(
                    label="Compare models with win ratios (as opposed to raw scores)",
                    value=True,
                    interactive=True,
                )
            with gr.Column():
                plot = gr.Plot(
                    value=produce_radial_plot(
                        default_model_ids,
                        language_names=default_language_names,
                        use_win_ratio=True,
                        results_dfs=results_dfs,
                    ),
                )

        language_names_dropdown.change(
            fn=partial(update_model_ids_dropdown, results_dfs=results_dfs),
            inputs=language_names_dropdown,
            outputs=model_ids_dropdown,
        )

        # Update plot when anything changes
        update_plot = partial(produce_radial_plot, results_dfs=results_dfs)
        plot_inputs = [
            model_ids_dropdown, language_names_dropdown, use_win_ratio_checkbox
        ]
        for component in (
            language_names_dropdown, model_ids_dropdown, use_win_ratio_checkbox
        ):
            component.change(fn=update_plot, inputs=plot_inputs, outputs=plot)

    demo.launch()
247
+
248
+
249
def update_model_ids_dropdown(
    language_names: list[str], results_dfs: dict[Language, pd.DataFrame] | None
) -> dict:
    """When the language names are updated, update the model ids dropdown.

    The new choices are all the models that have results for at least one of
    the selected languages, and the first of them (in sorted order) becomes the
    selected model.

    Args:
        language_names:
            The names of the languages to include in the plot.
        results_dfs:
            The results dataframes for each language.

    Returns:
        The Gradio update to the model ids dropdown. Both the choices and the
        selection are empty if no language is selected or no model has results.
    """
    if results_dfs is None or not language_names:
        return gr.update(choices=[], value=[])

    selected_language_names = set(language_names)

    # Sorted, as the order of a set is arbitrary and would otherwise make both
    # the dropdown order and the preselected model differ between runs.
    filtered_models = sorted(
        {
            model_id
            for language, df in results_dfs.items()
            if language.name in selected_language_names
            for model_id in df.index
        },
        key=str,
    )

    if not filtered_models:
        return gr.update(choices=[], value=[])

    # The value is a list, like in the branches above, since the dropdown is a
    # multiselect one.
    return gr.update(choices=filtered_models, value=[filtered_models[0]])
277
+
278
+
279
def produce_radial_plot(
    model_ids: list[str],
    language_names: list[str],
    use_win_ratio: bool,
    results_dfs: dict[Language, pd.DataFrame] | None
) -> go.Figure:
    """Produce a radial plot as a plotly figure.

    Each plotted model gets one trace, whose value for a task is the mean over
    the selected languages of either the model's score or its win ratio, being
    the share of models in that language that it scores at least as high as.

    Args:
        model_ids:
            The ids of the models to include in the plot.
        language_names:
            The names of the languages to include in the plot.
        use_win_ratio:
            Whether to use win ratios (as opposed to raw scores).
        results_dfs:
            The results dataframes for each language.

    Returns:
        A plotly figure. Models without results for every selected language are
        left out, and the figure is empty if there is nothing left to plot.

    Raises:
        KeyError:
            If a language name is not among the known languages.
    """
    # Tolerate a single model id where a list of them is expected
    if isinstance(model_ids, str):
        model_ids = [model_ids]

    if results_dfs is None or not language_names or not model_ids:
        return go.Figure()

    tasks = ALL_TASKS
    languages = [ALL_LANGUAGES[language_name] for language_name in language_names]

    def has_results_for_all_languages(model_id: str) -> bool:
        """Check that every selected language has all task scores for the model."""
        for language in languages:
            df = results_dfs.get(language)
            if df is None or model_id not in df.index:
                return False
            if any(task not in df.columns for task in tasks):
                return False
        return True

    # Averaging over a subset of the selected languages would be misleading, so
    # only models that are covered by all of them are plotted. This also covers
    # the case where the model selection has not caught up with the language
    # selection yet.
    plotted_model_ids = [
        model_id for model_id in model_ids if has_results_for_all_languages(model_id)
    ]
    if not plotted_model_ids:
        return go.Figure()

    # Add all the evaluation results for each model
    results: list[list[float]] = list()
    for model_id in plotted_model_ids:
        result_list = list()
        for task in tasks:
            values = list()
            for language in languages:
                df = results_dfs[language]
                score = df.loc[model_id][task]
                if use_win_ratio:
                    values.append(
                        np.mean(
                            [score >= other_score for other_score in df[task].dropna()]
                        )
                    )
                else:
                    values.append(score)
            result_list.append(np.mean(values))
        results.append(result_list)

    # Sort the results to avoid misleading radial plots. The tasks are ordered by
    # the values of the model whose values vary the most.
    model_idx_with_highest_variance = np.argmax(
        [np.std(result_list) for result_list in results]
    )
    sorted_idxs = np.argsort(results[model_idx_with_highest_variance])
    results = [np.asarray(result_list)[sorted_idxs] for result_list in results]
    tasks = [tasks[idx] for idx in sorted_idxs]
    task_names = [task.name for task in tasks]

    # Add the results to a plotly figure
    fig = go.Figure()
    for model_id, result_list in zip(plotted_model_ids, results):
        fig.add_trace(go.Scatterpolar(
            r=result_list,
            theta=task_names,
            fill='toself',
            name=model_id,
        ))

    # Join the language names as "A, B and C"
    languages_str = ""
    if len(languages) > 1:
        languages_str = ", ".join([language.name for language in languages[:-1]])
        languages_str += " and "
    languages_str += languages[-1].name

    if use_win_ratio:
        title = f'Win Ratio on {languages_str} Language Tasks'
    else:
        title = f'LLM Score on {languages_str} Language Tasks'

    # Builds the radial plot from the results
    fig.update_layout(
        polar=dict(radialaxis=dict(visible=True)), showlegend=True, title=title
    )

    return fig
368
+
369
# Only build and launch the demo when the file is run as a script.
if __name__ == "__main__":
    main()
requirements.txt ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ altair==5.2.0
3
+ annotated-types==0.6.0
4
+ anyio==4.2.0
5
+ attrs==23.2.0
6
+ certifi==2023.11.17
7
+ charset-normalizer==3.3.2
8
+ click==8.1.7
9
+ colorama==0.4.6
10
+ contourpy==1.2.0
11
+ cycler==0.12.1
12
+ exceptiongroup==1.2.0
13
+ fastapi==0.109.0
14
+ ffmpy==0.3.1
15
+ filelock==3.13.1
16
+ fonttools==4.47.2
17
+ fsspec==2023.12.2
18
+ gradio==4.15.0
19
+ gradio_client==0.8.1
20
+ h11==0.14.0
21
+ httpcore==1.0.2
22
+ httpx==0.26.0
23
+ huggingface-hub==0.20.3
24
+ idna==3.6
25
+ importlib-resources==6.1.1
26
+ Jinja2==3.1.3
27
+ jsonschema==4.21.1
28
+ jsonschema-specifications==2023.12.1
29
+ kiwisolver==1.4.5
30
+ markdown-it-py==3.0.0
31
+ MarkupSafe==2.1.4
32
+ matplotlib==3.8.2
33
+ mdurl==0.1.2
34
+ numpy==1.26.3
35
+ orjson==3.9.12
36
+ packaging==23.2
37
+ pandas==2.2.0
38
+ pillow==10.2.0
39
+ plotly==5.18.0
40
+ pyarrow==15.0.0
41
+ pydantic==2.5.3
42
+ pydantic_core==2.14.6
43
+ pydub==0.25.1
44
+ Pygments==2.17.2
45
+ pyparsing==3.1.1
46
+ python-dateutil==2.8.2
47
+ python-multipart==0.0.6
48
+ pytz==2023.3.post1
49
+ PyYAML==6.0.1
50
+ referencing==0.32.1
51
+ requests==2.31.0
52
+ rich==13.7.0
53
+ rpds-py==0.17.1
54
+ ruff==0.1.14
55
+ semantic-version==2.10.0
56
+ shellingham==1.5.4
57
+ six==1.16.0
58
+ sniffio==1.3.0
59
+ starlette==0.35.1
60
+ tenacity==8.2.3
61
+ tomlkit==0.12.0
62
+ toolz==0.12.1
63
+ tqdm==4.66.1
64
+ typer==0.9.0
65
+ typing_extensions==4.9.0
66
+ tzdata==2023.4
67
+ urllib3==2.1.0
68
+ uvicorn==0.27.0
69
+ websockets==11.0.3