|
import datetime |
|
import json |
|
import zipfile |
|
import stable_baselines3 |
|
|
|
|
|
def generate_config_json(model_fp, config_fp):
    """Write a model archive's ``data`` entry, plus system info, to a JSON file.

    The zip archive at ``model_fp`` is opened, its member named ``"data"`` is
    parsed as JSON, a ``"system_info"`` key is added to the parsed object, and
    the result is dumped to ``config_fp`` with 4-space indentation.

    Args:
        model_fp: Path (or file-like object) of the zip archive to read.
        config_fp: Path of the JSON file to create or overwrite.

    Raises:
        KeyError: If the archive has no member named ``"data"``.
        json.JSONDecodeError: If the ``"data"`` member is not valid JSON.
    """
    # Read everything needed from the archive first so that it is closed
    # again before the output file is opened.
    with zipfile.ZipFile(model_fp, "r") as archive:
        with archive.open("data") as member:
            config = json.load(member)

    # Only the first element of get_system_info()'s return value is stored.
    config["system_info"] = stable_baselines3.get_system_info(print_info=False)[0]

    with open(config_fp, "w", encoding="utf-8") as out_file:
        json.dump(config, out_file, indent=4)
|
|
|
def generate_results_json(results_fp, mean_reward, std_reward, n_eval_episodes, is_deterministic=True):
    """Write an evaluation summary to a JSON file.

    The file contains, in this order, the keys ``mean_reward``,
    ``std_reward``, ``is_deterministic``, ``n_eval_episodes`` and
    ``eval_datetime``, dumped with 4-space indentation.

    Args:
        results_fp: Path of the JSON file to create or overwrite.
        mean_reward: Value stored under ``"mean_reward"``.
        std_reward: Value stored under ``"std_reward"``.
        n_eval_episodes: Value stored under ``"n_eval_episodes"``.
        is_deterministic: Value stored under ``"is_deterministic"``.

    Raises:
        TypeError: If any of the values is not JSON serializable.
    """
    # Naive local time in ISO 8601 form (no UTC offset), taken when the
    # results file is generated.
    timestamp = datetime.datetime.now().isoformat()

    results = {
        "mean_reward": mean_reward,
        "std_reward": std_reward,
        "is_deterministic": is_deterministic,
        "n_eval_episodes": n_eval_episodes,
        "eval_datetime": timestamp,
    }

    with open(results_fp, "w", encoding="utf-8") as out_file:
        json.dump(results, out_file, indent=4)
|
|