Terry Zhuo
commited on
Commit
·
e50aaef
1
Parent(s):
9739b47
fix grid
Browse files- src/tools/plots.py +16 -10
src/tools/plots.py
CHANGED
@@ -21,23 +21,31 @@ def plot_solve_rate(df, task, rows=30, cols=38):
|
|
21 |
keys = df["task_id"]
|
22 |
values = df["solve_rate"]
|
23 |
|
24 |
-
values = np.array(values)
|
25 |
|
26 |
n = len(values)
|
27 |
-
|
28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
-
hover_text = np.empty_like(
|
31 |
for i in range(rows):
|
32 |
for j in range(cols):
|
33 |
-
if not
|
34 |
-
hover_text[i, j] = f"{keys[i, j]}<br>Solve Rate: {
|
35 |
else:
|
36 |
hover_text[i, j] = "NaN"
|
37 |
|
38 |
-
|
|
|
|
|
39 |
fig = go.Figure(data=go.Heatmap(
|
40 |
-
z=
|
41 |
text=hover_text,
|
42 |
hoverinfo='text',
|
43 |
colorscale='teal',
|
@@ -52,8 +60,6 @@ def plot_solve_rate(df, task, rows=30, cols=38):
|
|
52 |
xaxis=dict(showticklabels=False),
|
53 |
yaxis=dict(showticklabels=False),
|
54 |
autosize=True,
|
55 |
-
# width=760,
|
56 |
-
# height=600,
|
57 |
)
|
58 |
|
59 |
return fig
|
|
|
21 |
keys = df["task_id"]
|
22 |
values = df["solve_rate"]
|
23 |
|
24 |
+
values = np.array(values, dtype=float) # Ensure values are floats
|
25 |
|
26 |
n = len(values)
|
27 |
+
pad_width = rows * cols - n
|
28 |
+
|
29 |
+
# Use masked array to handle NaN values
|
30 |
+
masked_values = np.ma.array(values)
|
31 |
+
masked_values = np.ma.pad(masked_values, (0, pad_width), 'constant', constant_values=np.ma.masked)
|
32 |
+
masked_values = masked_values.reshape((rows, cols))
|
33 |
+
|
34 |
+
keys = np.pad(keys, (0, pad_width), 'constant', constant_values='').reshape((rows, cols))
|
35 |
|
36 |
+
hover_text = np.empty_like(masked_values, dtype=object)
|
37 |
for i in range(rows):
|
38 |
for j in range(cols):
|
39 |
+
if not masked_values.mask[i, j]:
|
40 |
+
hover_text[i, j] = f"{keys[i, j]}<br>Solve Rate: {masked_values[i, j]:.2f}"
|
41 |
else:
|
42 |
hover_text[i, j] = "NaN"
|
43 |
|
44 |
+
# Use compressed array to count non-masked (finite) values
|
45 |
+
upper_solve_rate = round(np.count_nonzero(~masked_values.mask) / n * 100, 2)
|
46 |
+
|
47 |
fig = go.Figure(data=go.Heatmap(
|
48 |
+
z=masked_values,
|
49 |
text=hover_text,
|
50 |
hoverinfo='text',
|
51 |
colorscale='teal',
|
|
|
60 |
xaxis=dict(showticklabels=False),
|
61 |
yaxis=dict(showticklabels=False),
|
62 |
autosize=True,
|
|
|
|
|
63 |
)
|
64 |
|
65 |
return fig
|