Terry Zhuo commited on
Commit
e50aaef
·
1 Parent(s): 9739b47
Files changed (1) hide show
  1. src/tools/plots.py +16 -10
src/tools/plots.py CHANGED
@@ -21,23 +21,31 @@ def plot_solve_rate(df, task, rows=30, cols=38):
21
  keys = df["task_id"]
22
  values = df["solve_rate"]
23
 
24
- values = np.array(values)
25
 
26
  n = len(values)
27
- values = np.pad(values, (0, rows * cols - n), 'constant', constant_values=np.nan).reshape((rows, cols))
28
- keys = np.pad(keys, (0, rows * cols - n), 'constant', constant_values='').reshape((rows, cols))
 
 
 
 
 
 
29
 
30
- hover_text = np.empty_like(values, dtype=object)
31
  for i in range(rows):
32
  for j in range(cols):
33
- if not np.isnan(values[i, j]):
34
- hover_text[i, j] = f"{keys[i, j]}<br>Solve Rate: {values[i, j]:.2f}"
35
  else:
36
  hover_text[i, j] = "NaN"
37
 
38
- upper_solve_rate = round(np.count_nonzero(values)/n*100, 2)
 
 
39
  fig = go.Figure(data=go.Heatmap(
40
- z=values,
41
  text=hover_text,
42
  hoverinfo='text',
43
  colorscale='teal',
@@ -52,8 +60,6 @@ def plot_solve_rate(df, task, rows=30, cols=38):
52
  xaxis=dict(showticklabels=False),
53
  yaxis=dict(showticklabels=False),
54
  autosize=True,
55
- # width=760,
56
- # height=600,
57
  )
58
 
59
  return fig
 
21
  keys = df["task_id"]
22
  values = df["solve_rate"]
23
 
24
+ values = np.array(values, dtype=float) # Ensure values are floats
25
 
26
  n = len(values)
27
+ pad_width = rows * cols - n
28
+
29
+ # Use masked array to handle NaN values
30
+ masked_values = np.ma.array(values)
31
+ masked_values = np.ma.pad(masked_values, (0, pad_width), 'constant', constant_values=np.ma.masked)
32
+ masked_values = masked_values.reshape((rows, cols))
33
+
34
+ keys = np.pad(keys, (0, pad_width), 'constant', constant_values='').reshape((rows, cols))
35
 
36
+ hover_text = np.empty_like(masked_values, dtype=object)
37
  for i in range(rows):
38
  for j in range(cols):
39
+ if not masked_values.mask[i, j]:
40
+ hover_text[i, j] = f"{keys[i, j]}<br>Solve Rate: {masked_values[i, j]:.2f}"
41
  else:
42
  hover_text[i, j] = "NaN"
43
 
44
+ # Use compressed array to count non-masked (finite) values
45
+ upper_solve_rate = round(np.count_nonzero(~masked_values.mask) / n * 100, 2)
46
+
47
  fig = go.Figure(data=go.Heatmap(
48
+ z=masked_values,
49
  text=hover_text,
50
  hoverinfo='text',
51
  colorscale='teal',
 
60
  xaxis=dict(showticklabels=False),
61
  yaxis=dict(showticklabels=False),
62
  autosize=True,
 
 
63
  )
64
 
65
  return fig