Commit
·
c34e772
1
Parent(s):
8157f53
feat: Add update colours button
Browse files
app.py
CHANGED
@@ -232,40 +232,30 @@ DATASETS = [
|
|
232 |
]
|
233 |
|
234 |
|
235 |
-
def
|
236 |
-
"""
|
237 |
|
238 |
-
|
239 |
-
|
240 |
-
|
|
|
|
|
|
|
|
|
241 |
|
242 |
-
|
243 |
-
[language.name for language in ALL_LANGUAGES.values()],
|
244 |
-
key=lambda language_name: language_name.lower(),
|
245 |
-
)
|
246 |
-
danish_models = sorted(
|
247 |
-
list({model_id for model_id in results_dfs[DANISH].index}),
|
248 |
-
key=lambda model_id: model_id.lower(),
|
249 |
-
)
|
250 |
|
251 |
# Get distinct RGB values for all models
|
252 |
all_models = list(
|
253 |
{model_id for df in results_dfs.values() for model_id in df.index}
|
254 |
)
|
255 |
-
colour_mapping
|
256 |
|
257 |
for i in it.count():
|
258 |
min_colour_distance = MIN_COLOUR_DISTANCE_BETWEEN_MODELS - i
|
259 |
-
|
260 |
-
if i > 0:
|
261 |
-
logger.info(
|
262 |
-
f"All retries failed. Trying again with min colour distance "
|
263 |
-
f"{min_colour_distance}."
|
264 |
-
)
|
265 |
-
|
266 |
retries_left = 10 * len(all_models)
|
267 |
for model_id in all_models:
|
268 |
-
random.seed(hash(model_id) + i)
|
269 |
r, g, b = 0, 0, 0
|
270 |
too_bright, similar_to_other_model = True, True
|
271 |
while (too_bright or similar_to_other_model) and retries_left > 0:
|
@@ -287,6 +277,28 @@ def main() -> None:
|
|
287 |
)
|
288 |
break
|
289 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
290 |
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
|
291 |
gr.Markdown(INTRO_MARKDOWN)
|
292 |
|
@@ -340,6 +352,11 @@ def main() -> None:
|
|
340 |
interactive=True,
|
341 |
scale=1,
|
342 |
)
|
|
|
|
|
|
|
|
|
|
|
343 |
with gr.Row():
|
344 |
plot = gr.Plot(
|
345 |
value=produce_radial_plot(
|
@@ -349,7 +366,6 @@ def main() -> None:
|
|
349 |
show_scale=show_scale_checkbox.value,
|
350 |
plot_width=plot_width_slider.value,
|
351 |
plot_height=plot_height_slider.value,
|
352 |
-
colour_mapping=colour_mapping,
|
353 |
results_dfs=results_dfs,
|
354 |
),
|
355 |
)
|
@@ -371,7 +387,6 @@ def main() -> None:
|
|
371 |
update_plot_kwargs = dict(
|
372 |
fn=partial(
|
373 |
produce_radial_plot,
|
374 |
-
colour_mapping=colour_mapping,
|
375 |
results_dfs=results_dfs,
|
376 |
),
|
377 |
inputs=[
|
@@ -391,6 +406,11 @@ def main() -> None:
|
|
391 |
plot_width_slider.change(**update_plot_kwargs)
|
392 |
plot_height_slider.change(**update_plot_kwargs)
|
393 |
|
|
|
|
|
|
|
|
|
|
|
394 |
demo.launch()
|
395 |
|
396 |
|
@@ -483,7 +503,6 @@ def produce_radial_plot(
|
|
483 |
show_scale: bool,
|
484 |
plot_width: int,
|
485 |
plot_height: int,
|
486 |
-
colour_mapping: dict[str, tuple[int, int, int]],
|
487 |
results_dfs: dict[Language, pd.DataFrame] | None,
|
488 |
) -> go.Figure:
|
489 |
"""Produce a radial plot as a plotly figure.
|
@@ -501,8 +520,6 @@ def produce_radial_plot(
|
|
501 |
The width of the plot.
|
502 |
plot_height:
|
503 |
The height of the plot.
|
504 |
-
colour_mapping:
|
505 |
-
A mapping from model ids to RGB triplets.
|
506 |
results_dfs:
|
507 |
The results dataframes for each language.
|
508 |
|
|
|
232 |
]
|
233 |
|
234 |
|
235 |
+
def update_colour_mapping(results_dfs: dict[Language, pd.DataFrame]) -> None:
|
236 |
+
"""Get a mapping from model ids to RGB triplets.
|
237 |
|
238 |
+
Args:
|
239 |
+
results_dfs:
|
240 |
+
The results dataframes for each language.
|
241 |
+
"""
|
242 |
+
global colour_mapping
|
243 |
+
global seed
|
244 |
+
seed += 1
|
245 |
|
246 |
+
gr.Info(f"Updating colour mapping...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
247 |
|
248 |
# Get distinct RGB values for all models
|
249 |
all_models = list(
|
250 |
{model_id for df in results_dfs.values() for model_id in df.index}
|
251 |
)
|
252 |
+
colour_mapping = dict()
|
253 |
|
254 |
for i in it.count():
|
255 |
min_colour_distance = MIN_COLOUR_DISTANCE_BETWEEN_MODELS - i
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
256 |
retries_left = 10 * len(all_models)
|
257 |
for model_id in all_models:
|
258 |
+
random.seed(hash(model_id) + i + seed)
|
259 |
r, g, b = 0, 0, 0
|
260 |
too_bright, similar_to_other_model = True, True
|
261 |
while (too_bright or similar_to_other_model) and retries_left > 0:
|
|
|
277 |
)
|
278 |
break
|
279 |
|
280 |
+
|
281 |
+
def main() -> None:
|
282 |
+
"""Produce a radial plot."""
|
283 |
+
|
284 |
+
global last_fetch
|
285 |
+
results_dfs = fetch_results()
|
286 |
+
last_fetch = dt.datetime.now()
|
287 |
+
|
288 |
+
all_languages = sorted(
|
289 |
+
[language.name for language in ALL_LANGUAGES.values()],
|
290 |
+
key=lambda language_name: language_name.lower(),
|
291 |
+
)
|
292 |
+
danish_models = sorted(
|
293 |
+
list({model_id for model_id in results_dfs[DANISH].index}),
|
294 |
+
key=lambda model_id: model_id.lower(),
|
295 |
+
)
|
296 |
+
|
297 |
+
global colour_mapping
|
298 |
+
global seed
|
299 |
+
seed = 4242
|
300 |
+
update_colour_mapping(results_dfs=results_dfs)
|
301 |
+
|
302 |
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
|
303 |
gr.Markdown(INTRO_MARKDOWN)
|
304 |
|
|
|
352 |
interactive=True,
|
353 |
scale=1,
|
354 |
)
|
355 |
+
update_colours_button = gr.Button(
|
356 |
+
value="Update colours",
|
357 |
+
interactive=True,
|
358 |
+
scale=1,
|
359 |
+
)
|
360 |
with gr.Row():
|
361 |
plot = gr.Plot(
|
362 |
value=produce_radial_plot(
|
|
|
366 |
show_scale=show_scale_checkbox.value,
|
367 |
plot_width=plot_width_slider.value,
|
368 |
plot_height=plot_height_slider.value,
|
|
|
369 |
results_dfs=results_dfs,
|
370 |
),
|
371 |
)
|
|
|
387 |
update_plot_kwargs = dict(
|
388 |
fn=partial(
|
389 |
produce_radial_plot,
|
|
|
390 |
results_dfs=results_dfs,
|
391 |
),
|
392 |
inputs=[
|
|
|
406 |
plot_width_slider.change(**update_plot_kwargs)
|
407 |
plot_height_slider.change(**update_plot_kwargs)
|
408 |
|
409 |
+
# Update colours when the button is clicked
|
410 |
+
update_colours_button.click(
|
411 |
+
fn=partial(update_colour_mapping, results_dfs=results_dfs),
|
412 |
+
).then(**update_plot_kwargs)
|
413 |
+
|
414 |
demo.launch()
|
415 |
|
416 |
|
|
|
503 |
show_scale: bool,
|
504 |
plot_width: int,
|
505 |
plot_height: int,
|
|
|
506 |
results_dfs: dict[Language, pd.DataFrame] | None,
|
507 |
) -> go.Figure:
|
508 |
"""Produce a radial plot as a plotly figure.
|
|
|
520 |
The width of the plot.
|
521 |
plot_height:
|
522 |
The height of the plot.
|
|
|
|
|
523 |
results_dfs:
|
524 |
The results dataframes for each language.
|
525 |
|