Joschka Strueber commited on
Commit
4adb140
·
1 Parent(s): c8f741c

[Fix] rendering issues

Browse files
Files changed (4) hide show
  1. app.py +22 -12
  2. src/dataloading.py +1 -1
  3. src/heatmap.html +0 -0
  4. src/test.py +0 -14
app.py CHANGED
@@ -1,23 +1,21 @@
1
  import gradio as gr
2
  import plotly.graph_objects as go
3
- import plotly.io as pio
4
  import numpy as np
5
  from src.dataloading import get_leaderboard_models_cached, get_leaderboard_datasets
6
 
 
 
7
  pio.renderers.default = "iframe"
8
 
9
  def create_heatmap(selected_models, selected_dataset):
10
- # Return nothing if no inputs are provided
11
  if not selected_models or not selected_dataset:
12
  return None
13
 
14
- # Generate a random symmetric similarity matrix
15
  size = len(selected_models)
16
  similarities = np.random.rand(size, size)
17
  similarities = (similarities + similarities.T) / 2
18
  similarities = np.round(similarities, 2)
19
 
20
- # Create a heatmap trace using go.Heatmap; we set x and y to the model names.
21
  fig = go.Figure(data=go.Heatmap(
22
  z=similarities,
23
  x=selected_models,
@@ -27,8 +25,7 @@ def create_heatmap(selected_models, selected_dataset):
27
  text=similarities,
28
  hoverinfo="text"
29
  ))
30
-
31
- # Update layout: add title, axis titles, set fixed dimensions and margins
32
  fig.update_layout(
33
  title=f"Similarity Matrix for {selected_dataset}",
34
  xaxis_title="Models",
@@ -37,7 +34,24 @@ def create_heatmap(selected_models, selected_dataset):
37
  height=800,
38
  margin=dict(l=100, r=100, t=100, b=100)
39
  )
40
- return fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  def validate_inputs(selected_models, selected_dataset):
43
  if not selected_models:
@@ -47,7 +61,6 @@ def validate_inputs(selected_models, selected_dataset):
47
 
48
  with gr.Blocks(title="LLM Similarity Analyzer") as demo:
49
  gr.Markdown("## Model Similarity Comparison Tool")
50
-
51
  with gr.Row():
52
  dataset_dropdown = gr.Dropdown(
53
  choices=get_leaderboard_datasets(),
@@ -66,10 +79,8 @@ with gr.Blocks(title="LLM Similarity Analyzer") as demo:
66
  )
67
 
68
  generate_btn = gr.Button("Generate Heatmap", variant="primary")
69
- # Initialize the Plot component without a figure (it will be updated)
70
  heatmap = gr.Plot(label="Similarity Heatmap", visible=True)
71
 
72
- # First validate inputs, then create the heatmap; note that we use a single output.
73
  generate_btn.click(
74
  fn=validate_inputs,
75
  inputs=[model_dropdown, dataset_dropdown],
@@ -80,7 +91,6 @@ with gr.Blocks(title="LLM Similarity Analyzer") as demo:
80
  outputs=heatmap
81
  )
82
 
83
- # Clear button to reset selections and clear the plot
84
  clear_btn = gr.Button("Clear Selection")
85
  clear_btn.click(
86
  lambda: [None, None, None],
@@ -88,4 +98,4 @@ with gr.Blocks(title="LLM Similarity Analyzer") as demo:
88
  )
89
 
90
  if __name__ == "__main__":
91
- demo.launch(ssr_mode=False)
 
1
  import gradio as gr
2
  import plotly.graph_objects as go
 
3
  import numpy as np
4
  from src.dataloading import get_leaderboard_models_cached, get_leaderboard_datasets
5
 
6
+ # Force the Plotly renderer to use an iframe-based output
7
+ import plotly.io as pio
8
  pio.renderers.default = "iframe"
9
 
10
  def create_heatmap(selected_models, selected_dataset):
 
11
  if not selected_models or not selected_dataset:
12
  return None
13
 
 
14
  size = len(selected_models)
15
  similarities = np.random.rand(size, size)
16
  similarities = (similarities + similarities.T) / 2
17
  similarities = np.round(similarities, 2)
18
 
 
19
  fig = go.Figure(data=go.Heatmap(
20
  z=similarities,
21
  x=selected_models,
 
25
  text=similarities,
26
  hoverinfo="text"
27
  ))
28
+
 
29
  fig.update_layout(
30
  title=f"Similarity Matrix for {selected_dataset}",
31
  xaxis_title="Models",
 
34
  height=800,
35
  margin=dict(l=100, r=100, t=100, b=100)
36
  )
37
+
38
+ # (Optional) Force categorical ordering explicitly
39
+ fig.update_xaxes(
40
+ type="category",
41
+ categoryorder="array",
42
+ categoryarray=selected_models,
43
+ tickangle=45,
44
+ automargin=True
45
+ )
46
+ fig.update_yaxes(
47
+ type="category",
48
+ categoryorder="array",
49
+ categoryarray=selected_models,
50
+ automargin=True
51
+ )
52
+
53
+ # Return a fully serializable dictionary
54
+ return fig.to_dict()
55
 
56
  def validate_inputs(selected_models, selected_dataset):
57
  if not selected_models:
 
61
 
62
  with gr.Blocks(title="LLM Similarity Analyzer") as demo:
63
  gr.Markdown("## Model Similarity Comparison Tool")
 
64
  with gr.Row():
65
  dataset_dropdown = gr.Dropdown(
66
  choices=get_leaderboard_datasets(),
 
79
  )
80
 
81
  generate_btn = gr.Button("Generate Heatmap", variant="primary")
 
82
  heatmap = gr.Plot(label="Similarity Heatmap", visible=True)
83
 
 
84
  generate_btn.click(
85
  fn=validate_inputs,
86
  inputs=[model_dropdown, dataset_dropdown],
 
91
  outputs=heatmap
92
  )
93
 
 
94
  clear_btn = gr.Button("Clear Selection")
95
  clear_btn.click(
96
  lambda: [None, None, None],
 
98
  )
99
 
100
  if __name__ == "__main__":
101
+ demo.launch(ssr_mode=False)
src/dataloading.py CHANGED
@@ -6,7 +6,7 @@ def get_leaderboard_models():
6
  api = HfApi()
7
 
8
  # List all datasets in the open-llm-leaderboard organization
9
- datasets = api.list_datasets(author="open-llm-leaderboard")
10
 
11
  models = []
12
  #for dataset in datasets:
 
6
  api = HfApi()
7
 
8
  # List all datasets in the open-llm-leaderboard organization
9
+ #datasets = api.list_datasets(author="open-llm-leaderboard")
10
 
11
  models = []
12
  #for dataset in datasets:
src/heatmap.html DELETED
The diff for this file is too large to render. See raw diff
 
src/test.py DELETED
@@ -1,14 +0,0 @@
1
- import plotly.graph_objects as go
2
- import numpy as np
3
-
4
- models = ["model1", "model2", "model3"]
5
- size = len(models)
6
- sim = np.random.rand(size, size)
7
- sim = (sim + sim.T) / 2
8
- sim = np.round(sim, 2)
9
- fig = go.Figure(data=go.Heatmap(z=sim, x=models, y=models, colorscale="Viridis"))
10
- fig.update_layout(title="Test Heatmap", xaxis_title="Models", yaxis_title="Models", width=800, height=800)
11
- fig.show()
12
-
13
- # Save fig
14
- fig.write_html("heatmap.html")