File size: 1,329 Bytes
47aa47a ac49cb8 e036817 ac49cb8 71d2358 ac49cb8 e036817 ac49cb8 47aa47a ac49cb8 71d2358 ac49cb8 71d2358 e036817 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
# import argparse
import numpy as np
import os
from matplotlib import pyplot as plt
def calc_stats(filepath, trim=1):
    """Return the outlier-trimmed mean and standard deviation of evaluation scores.

    Loads the ``"results"`` array from the ``.npz`` archive at *filepath*,
    sorts every row (axis 1), discards the *trim* lowest and *trim* highest
    values of each row, and computes the mean and population standard
    deviation (``ddof=0``) over all remaining values.

    Args:
        filepath: Path to an ``.npz`` file containing a ``"results"`` array
            with at least two dimensions. Presumably one row per evaluation
            and one column per episode -- confirm against the code that
            writes the file.
        trim: Number of values dropped from each end of every row after
            sorting. The default of 1 drops the single lowest and highest
            value, matching the original behavior.

    Returns:
        ``(avg, std)``, each rounded to two decimal places.

    Raises:
        ValueError: If *trim* is negative, or if the rows are too short to
            leave any values after trimming (previously this silently
            produced ``nan``).
    """
    if trim < 0:
        raise ValueError(f"trim must be non-negative, got {trim}")

    # NpzFile holds the underlying file open; the context manager closes it.
    # Indexing the archive loads the array fully, so it stays usable after.
    with np.load(filepath) as archive:
        data = archive["results"]

    ordered = np.sort(data, axis=1)
    n_cols = ordered.shape[1]
    if n_cols <= 2 * trim:
        raise ValueError(
            f"{filepath}: rows have {n_cols} values; need more than "
            f"{2 * trim} to trim {trim} from each end"
        )

    # Explicit stop index: ordered[:, trim:-trim] would be empty for trim == 0.
    trimmed = ordered[:, trim:n_cols - trim]
    avg = round(np.mean(trimmed), 2)
    std = round(np.std(trimmed), 2)
    return avg, std
# TODO: restore the CLI options (-f/--filepath, -s/--save) with argparse if
# this script needs to target other agents or save the chart on demand.

# Collect the evaluation file of every "dqn_v2" training run under agents/.
# os.listdir returns entries in arbitrary order, so sort it to keep the
# "Training Run" numbering stable between invocations.
AGENTS_DIR = "agents"
filepaths = []
for d in sorted(os.listdir(AGENTS_DIR)):
    if "dqn_v2" not in d:
        continue
    path = os.path.join(AGENTS_DIR, d, "evaluations.npz")
    # Skip matches that are not run directories with an evaluations file
    # (e.g. a saved model archive whose name also contains "dqn_v2").
    if not os.path.isfile(path):
        print(f"Skipping {d}: no evaluations.npz found")
        continue
    filepaths.append(path)

means = []
stds = []
for path in filepaths:
    avg, std = calc_stats(path)
    means.append(avg)
    stds.append(std)

# 1-based run numbers for the x-axis.
runs = list(range(1, len(filepaths) + 1))

# Draw the two series side by side. Drawing both at the same x would let a
# std bar taller than its mean bar (or a negative mean) hide the mean.
bar_width = 0.4
plt.xlabel("Training Run")
plt.ylabel("Score")
plt.bar([r - bar_width / 2 for r in runs], means, width=bar_width)
plt.bar([r + bar_width / 2 for r in runs], stds, width=bar_width)
plt.xticks(runs)
plt.legend(["Mean evaluation score", "Standard deviation"])
plt.title("Average Evaluation Score and Standard Deviation\nAdjusted for Outliers Agent: dqn_v2")
# To save the chart, call savefig *before* show(): once the show() window is
# closed the figure is discarded, and a later savefig writes a blank image.
# plt.savefig("charts/fig1")
plt.show()
|