ledmands commited on
Commit
ac49cb8
1 Parent(s): 449b9e9

plot_improvement works OK, could use fine tuning

Browse files
Files changed (1) hide show
  1. plot_improvement.py +44 -73
plot_improvement.py CHANGED
@@ -1,79 +1,50 @@
1
  import argparse
2
- from numpy import load, ndarray
3
  import os
4
-
5
- parser = argparse.ArgumentParser()
6
- parser.add_argument("-f", "--filepath", required=True, help="Specify the file path to the agent.", type=str)
7
- args = parser.parse_args()
8
-
9
- filepath = args.filepath
10
- npdata = load(filepath)
11
-
12
- evaluations = ndarray.tolist(npdata['results'])
13
- # print(evaluations)
14
- sorted_evals = []
15
- for eval in evaluations:
16
- sorted_evals.append(sorted(eval))
17
-
18
- # Now I have a sorted list.
19
- # Now just pop the first and last elements of each eval
20
- for eval in sorted_evals:
21
- eval.pop(0)
22
- eval.pop()
23
-
24
- print()
25
- # print(sorted_evals)
26
-
27
- # Now that I have my sorted evaluations, I can calculate the mean episode reward for each eval
28
- mean_eval_rewards = []
29
- for eval in sorted_evals:
30
- mean_eval_rewards.append(sum(eval) / len(eval))
31
-
32
- # Now I should have a list with the mean evaluation reward with the highest and lowest score tossed out
33
- print(mean_eval_rewards)
34
- print("num evals: " + str(len(mean_eval_rewards)))
35
-
36
- # I'm dealing with a 2D array. Each element contains an array of ten data points
37
- # The number of elements is going to vary for each training run
38
- # The number of evaluation episodes will be constant, 10.
39
- # I need to convert to a regular list first
40
- # I could iterate over each element
41
-
42
- agent_dirs = []
43
  for d in os.listdir("agents/"):
44
  if "dqn_v2" in d:
45
- agent_dirs.append(d)
46
- # Now I have a list of dirs with the evals.
47
- # Iterate over the dirs, append the file path, load the evals, calculate the average score of the eval, then return a list with averages
48
- eval_list = []
49
- for d in agent_dirs:
50
- path = "agents/" + d + "/evaluations.npz"
51
- evals = ndarray.tolist(load(path)["results"])
52
- eval_list.append(evals)
53
- # for i in eval_list:
54
- # print(i)
55
- # print()
56
-
57
- def remove_outliers(evals: list) -> list:
58
- trimmed = []
59
- for eval in evals:
60
- eval.sort()
61
- eval.pop(0)
62
- eval.pop()
63
- trimmed.append(eval)
64
- return trimmed
65
 
66
- avgs = [[]]
67
- index = 0
68
- for i in eval_list:
69
- avgs.append(i)
70
- for j in i:
71
- j.sort()
72
- j.pop()
73
- j.pop(0)
74
- avgs[index].append(sum(j) / len(j))
75
- index += 1
76
-
77
- print(avgs)
78
 
79
-
 
1
  import argparse
2
+ import numpy as np
3
  import os
4
+ from matplotlib import pyplot as plt
5
+
6
+ def calc_stats(filepath):
7
+ # load the numpy file
8
+ data = np.load(filepath)["results"]
9
+ # sort the arrays and delete the first and last elements
10
+ data = np.sort(data, axis=1)
11
+ data = np.delete(data, -1, axis=1)
12
+ data = np.delete(data, 0, axis=1)
13
+ avg = round(np.mean(data), 2)
14
+ std = round(np.std(data), 2)
15
+ return avg, std
16
+
17
+ # parser = argparse.ArgumentParser()
18
+ # parser.add_argument("-f", "--filepath", required=True, help="Specify the file path to the agent.", type=str)
19
+ # args = parser.parse_args()
20
+
21
+ # Get the file paths and store in list.
22
+ # For each file path, I want to calculate the mean reward. This would be the mean reward for the training run over all evaluations.
23
+ # For each file path, append the mean reward to an averages list
24
+ # Plot the averages!
25
+
26
+ filepaths = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  for d in os.listdir("agents/"):
28
  if "dqn_v2" in d:
29
+ path = "agents/" + d + "/evaluations.npz"
30
+ filepaths.append(path)
31
+
32
+ means = []
33
+ stds = []
34
+ for path in filepaths:
35
+ avg, std = calc_stats(path)
36
+ means.append(avg)
37
+ stds.append(std)
38
+
39
+ runs = []
40
+ for i in range(len(filepaths)):
41
+ runs.append(i + 1)
42
+ plt.xlabel("training runs")
43
+ plt.ylabel("score")
44
+ plt.bar(runs, means)
45
+ plt.bar(runs, stds)
46
+ plt.legend(["Mean evaluation score", "Standard deviation"])
47
+ plt.title("Average Evaluation Score and Standard Deviation\nAdjusted for Outliers\nAgent: dqn_v2")
48
+ plt.show()
49
 
 
 
 
 
 
 
 
 
 
 
 
 
50