# import argparse | |
import numpy as np | |
import os | |
from matplotlib import pyplot as plt | |
def calc_stats(filepath):
    """Return the outlier-trimmed mean and standard deviation of an evaluation file.

    Loads the ``"results"`` array from the ``.npz`` archive at *filepath*,
    sorts every row along axis 1, drops the smallest and the largest entry of
    each row, and computes the mean and standard deviation (NumPy's default,
    i.e. population standard deviation) over all remaining values together.

    Args:
        filepath: Path to an ``.npz`` archive holding a ``"results"`` array
            with at least two dimensions. Presumably one row per evaluation
            round and one column per episode -- confirm against the code
            that writes the file.

    Returns:
        A ``(avg, std)`` tuple, each value rounded to two decimal places.
        If the rows have fewer than three entries nothing is left after
        trimming and both values are ``nan``.
    """
    # np.load keeps the archive's file handle open until it is closed, so
    # read the array inside a context manager instead of leaking the handle.
    with np.load(filepath) as archive:
        data = archive["results"]

    # After sorting, the first and last column of each row are that row's
    # minimum and maximum; slice them off to dampen outliers.
    trimmed = np.sort(data, axis=1)[:, 1:-1]

    avg = round(np.mean(trimmed), 2)
    std = round(np.std(trimmed), 2)
    return avg, std
# parser = argparse.ArgumentParser()
# parser.add_argument("-f", "--filepath", required=True, help="Specify the file path to the agent.", type=str)
# parser.add_argument("-s", "--save", help="Specify whether to save the chart.", action="store_const", const=True)
# args = parser.parse_args()
# Collect the evaluation file of every "dqn_v2" training run under agents/.
# os.listdir returns entries in arbitrary order, so sort them (plain
# lexicographic string order) to keep the run numbering on the x-axis
# reproducible between invocations.
filepaths = [
    "agents/" + d + "/evaluations.npz"
    for d in sorted(os.listdir("agents/"))
    if "dqn_v2" in d
]

# Trimmed mean / standard deviation per run, in the same order as filepaths.
means = []
stds = []
for path in filepaths:
    avg, std = calc_stats(path)
    means.append(avg)
    stds.append(std)

# Runs are numbered 1..N on the x-axis.
runs = list(range(1, len(filepaths) + 1))

plt.xlabel("Training Run")
plt.ylabel("Score")
# Both series are drawn at the same x positions, so the standard-deviation
# bars are painted on top of the mean bars.
plt.bar(runs, means)
plt.bar(runs, stds)
plt.legend(["Mean evaluation score", "Standard deviation"])
plt.title("Average Evaluation Score and Standard Deviation\nAdjusted for Outliers Agent: dqn_v2")
plt.show()
# NOTE(review): if saving is re-enabled, it likely needs to run before
# plt.show(), since the figure may already be closed afterwards -- confirm.
# plt.savefig("charts/fig1")