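"""Evaluate saved PPO LunarLander-v2 models across random seeds and env counts.

Repeatedly evaluates each checkpoint with randomly drawn (seed, n_envs)
combinations, appends the results to a CSV, and prints an aggregated
comparison table.
"""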
import os
import random

import pandas as pd

from hf_helpers.sb3_eval import eval_model_with_seed
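
# eval_model_with_seed is a project-local helper whose source is not shown
# here. Judging by the call site below, a minimal sketch (an assumption, not
# the actual implementation) might look like:
#
#     from stable_baselines3 import PPO
#     from stable_baselines3.common.env_util import make_vec_env
#     from stable_baselines3.common.evaluation import evaluate_policy
#
#     def eval_model_with_seed(model_fp, env_id, seed, n_eval_episodes, n_envs):
#         env = make_vec_env(env_id, n_envs=n_envs, seed=seed)
#         model = PPO.load(model_fp)
#         mean_reward, std_reward = evaluate_policy(
#             model, env, n_eval_episodes=n_eval_episodes
#         )
#         # Assumption: the Hugging Face Deep RL course scores runs as
#         # mean_reward - std_reward.
#         result = mean_reward - std_reward
#         return result, mean_reward, std_reward
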
env_id = "LunarLander-v2"
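
# Checkpoints to compare. The filenames appear to encode a training seed
# (e.g. 123_456_789) and which hyperparameter defaults were used (Hugging
# Face course vs. Stable-Baselines3).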
models_to_evaluate = [
"ppo-LunarLander-v2_001_000_000_hf_defaults.zip",
"ppo-LunarLander-v2_010_000_000_hf_defaults.zip",
"ppo-LunarLander-v2_010_000_000_sb3_defaults.zip",
"ppo-LunarLander-v2_123_456_789_hf_defaults.zip",
]
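
# store_results appends, so results accumulate across runs; delete this file
# to start a fresh evaluation.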
evaluation_results_fp = "evaluation_results.csv"


def store_results(results):
    """Append a batch of result dicts to the results CSV.

    The header is written only when the file does not exist yet, so repeated
    appends produce one well-formed CSV.
    """
    results_df = pd.DataFrame(results)
    header = not os.path.exists(evaluation_results_fp)
    results_df.to_csv(evaluation_results_fp, mode="a", index=False, header=header)


def evaluate_and_store_all_results():
    """Evaluate every model under many random (seed, n_envs) combinations."""
    results = []
    n_evaluations = 1000
    for i in range(n_evaluations):
        # Flush accumulated results every 10 iterations so a crash loses at
        # most one batch.
        if i > 0 and i % 10 == 0:
            print(f"Progress: {i}/{n_evaluations}")
            store_results(results)
            results = []
        # seed = random.randint(0, 1000000000000)  # Why this interval?
        seed = random.randint(0, 10000)  # Also try some smaller numbers for seed
        n_envs = random.randint(1, 16)
        for model_fp in models_to_evaluate:
            result, mean_reward, std_reward = eval_model_with_seed(
                model_fp, env_id, seed, n_eval_episodes=10, n_envs=n_envs
            )
            result_data = {
                "model_fp": model_fp,
                "seed": seed,
                "n_envs": n_envs,
                "result": result,
                "mean_reward": mean_reward,
                "std_reward": std_reward,
            }
            results.append(result_data)
    # The in-loop flush only fires on multiples of 10, so store whatever is
    # left over once the loop finishes.
    if results:
        store_results(results)


def analyze_results():
    """Aggregate the stored results per model and print a markdown table."""
    results_df = pd.read_csv(evaluation_results_fp)
    # Strip the ".zip" extension so the table shows clean model names.
    results_df["model_fp"] = results_df["model_fp"].str.replace(".zip", "", regex=False)
    aggregated_results = (
        results_df.groupby("model_fp")["result"]
        .agg(["count", "min", "max", "mean"])
        .reset_index()
    )
    aggregated_results.columns = [
        "Model name",
        "Number of results",
        "Min",
        "Max",
        "Average",
    ]
    aggregated_results = aggregated_results.sort_values(by="Model name")
    print(aggregated_results.to_markdown(index=False, tablefmt="pipe"))


if __name__ == "__main__":
    # Uncomment to (re)generate evaluation_results.csv before analyzing.
    # evaluate_and_store_all_results()
    analyze_results()