# NOTE: stray terminal/log output ("Spaces:", "Running") removed from the top
# of the file; it was not part of the script.
#!/usr/bin/env python
import itertools
import math
import os
import pickle
import re
import sys
import time
from collections import OrderedDict, defaultdict
from itertools import chain
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# import seaborn
from termcolor import cprint, colored
# ---------------------------------------------------------------------------
# Module-level plotting configuration.
# ---------------------------------------------------------------------------

# Metric (read from the testing_*.pkl files) shown in the evaluation plots.
eval_metric = "test_success_rates"
# eval_metric = "exploration_bonus_mean"

# Figure super-title (empty -> none).  Earlier case-study titles kept below.
super_title = ""
# super_title = "PPO - No exploration bonus"
# super_title = "Count Based exploration bonus (Grid Search)"
# super_title = "PPO + RND"
# super_title = "PPO + RIDE"

agg_title = ""           # title for the aggregated-test figure
color_dict = None        # optional label -> color mapping (set further down)
eval_filename = None     # if set, figures are saved to this path
max_frames = 20_000_000  # x-axis cap in environment steps

draw_legend = True
per_seed = False         # plot every seed separately instead of mean +- err
study_eval = True
plot_train = True
plot_test = True
plot_aggregated_test = False
plot_only_aggregated_test = False

train_inc_font = 3       # extra font size for the training figure
xnbins = 4               # number of x-axis ticks
ynbins = 3               # number of y-axis ticks
steps_denom = 1e6        # x-axis is shown in millions of steps

# Global vars for tracking and labeling data at load time.
exp_idx = 0
label_parser_dict = None


# Default label parser: identity.  Redefined further down for the final case
# studies.  (Was a lambda assigned to a name; a def is the idiomatic form.)
def label_parser(l, _, label_parser_dict):
    return l


# Training-curve smoothing window; the last assignment deliberately wins
# (0 -> no smoothing).
# smooth_factor = 100
smooth_factor = 10
smooth_factor = 0
print("smooth factor:", smooth_factor)

eval_smooth_factor = 1
leg_size = 30
def smooth(x_, n=50):
    """Trailing moving average: element i is the mean of ``x_[max(i-n, 0):i+1]``.

    Args:
        x_: any 1-D sequence (list, tuple, ndarray).  Generalized from the
            original list-only conversion via ``np.asarray``.
        n: window length (number of *previous* elements averaged in).

    Returns:
        np.ndarray of the same length as ``x_`` (empty array for empty input).
    """
    x_ = np.asarray(x_)
    return np.array([x_[max(i - n, 0):i + 1].mean() for i in range(len(x_))])
sort_test = False


def sort_test_set(env_name):
    """Return the canonical sort index of a test-environment name.

    The canonical ordering is every (help-type x problem) combination in
    problem-major order, followed by the collaboration/adversarial envs.
    Matching is by substring: the first canonical name contained in
    *env_name* determines the index.

    Raises:
        ValueError: if *env_name* contains no known environment name.
    """
    helps = [
        "LanguageFeedback",
        "LanguageColor",
        "Pointing",
        "Emulation",
    ]
    problems = [
        "Boxes",
        "Switches",
        "Generators",
        "Marble",
        "Doors",
        "Levers",
    ]
    env_names = [h + p for p in problems for h in helps]
    env_names.extend([
        "LeverDoorColl",
        "MarblePushColl",
        "MarblePassColl",
        "AppleStealing",
    ])
    for i, en in enumerate(env_names):
        if en in env_name:
            return i
    raise ValueError(f"Test env {env_name} not known")
# Sub-sampling applied at plot time and when loading the raw logs.
subsample_step = 1
load_subsample_step = 1

# X-axis limits (millions of env steps); the second assignment of max_x_lim
# deliberately disables the cap.
x_lim = 0
max_x_lim = 17
max_x_lim = np.inf
# x_lim = 100

# Filled at plot time: per-label summary values and their colors.
summary_dict = {}
summary_dict_colors = {}

# Large repeated palette so we never run out of colors.
# (An older variant padded the tail with 'y' entries instead of repeating.)
default_colors_ = [
    "blue", "orange", "green", "magenta", "brown", "red", "black", "grey",
    u'#ff7f0e', "cyan", "pink", "purple", u'#1f77b4', "darkorchid", "sienna",
    "lightpink", "indigo", "mediumseagreen", "aqua", "deeppink", "silver",
    "khaki", "goldenrod",
] * 100
def get_eval_data(logdir, eval_metric):
    """Recursively collect evaluation pickles under *logdir*.

    Every file named ``testing_<ENV>.pkl`` is treated as one seed's evaluation
    history for test environment ``<ENV>``.  Each pickle must hold a dict with
    at least *eval_metric* and ``"test_step_nb"`` entries.

    Returns:
        dict mapping test-env name -> {"values": ndarray (seeds, T),
        "steps": ndarray (seeds, T)}; every seed's series is truncated to the
        shortest seed so the arrays are rectangular.
    """
    prefix, suffix = "testing_", ".pkl"
    eval_data = defaultdict(lambda: defaultdict(list))
    for root, _, files in os.walk(logdir):
        for file in files:
            if prefix not in file:
                continue
            assert suffix in file
            # BUGFIX: the original used lstrip/rstrip, which strip *character
            # sets* (not prefixes/suffixes) and mangled names such as
            # "MarblePushColl" -> "MarblePushCo".
            test_env_name = file
            if test_env_name.startswith(prefix):
                test_env_name = test_env_name[len(prefix):]
            if test_env_name.endswith(suffix):
                test_env_name = test_env_name[:-len(suffix)]
            try:
                with open(root + "/" + file, "rb") as f:
                    seed_eval_data = pickle.load(f)
            except (OSError, pickle.UnpicklingError, EOFError):
                # Best effort: skip unreadable (e.g. partially written) pickles.
                print("Pickle not loaded: ", root + "/" + file)
                time.sleep(1)
                continue
            eval_data[test_env_name]["values"].append(seed_eval_data[eval_metric])
            eval_data[test_env_name]["steps"].append(seed_eval_data["test_step_nb"])

    # Truncate every seed to the shortest one so the arrays are rectangular.
    for test_env, seed_data in eval_data.items():
        min_len_seed = min(len(s) for s in seed_data['steps'])
        eval_data[test_env]["values"] = np.array([s[:min_len_seed] for s in eval_data[test_env]["values"]])
        eval_data[test_env]["steps"] = np.array([s[:min_len_seed] for s in eval_data[test_env]["steps"]])
    return eval_data
def get_all_runs(logdir, load_subsample_step=1):
    """
    Recursively look through *logdir* for run directories containing a
    'log.csv' progress file and load each one into a dict of columns.

    Empty CSVs are skipped with a warning; malformed ones raise ValueError.
    Increments the module-level ``exp_idx`` counter for every run loaded.

    Args:
        logdir: root directory to walk.
        load_subsample_step: keep every n-th CSV row (memory saver).

    Returns:
        list of per-run dicts (CSV columns as lists, plus a 'config' key).
    """
    global exp_idx
    datasets = []
    for root, _, files in os.walk(logdir):
        if 'log.csv' not in files:
            continue
        csv_path = os.path.join(root, 'log.csv')
        if (Path(root) / 'log.csv').stat().st_size == 0:
            print("CSV {} empty".format(csv_path))
            continue
        run_name = root[8:]  # strips the leading "storage/" prefix
        config = None
        exp_idx += 1
        # load progress data; chain the original parse error for debugging
        try:
            exp_data = pd.read_csv(csv_path)
            print("Loaded:", csv_path)
        except (pd.errors.EmptyDataError, pd.errors.ParserError, OSError) as e:
            raise ValueError("CSV {} faulty".format(csv_path)) from e
        exp_data = exp_data[::load_subsample_step]
        data_dict = exp_data.to_dict("list")
        data_dict['config'] = config
        nb_epochs = len(data_dict['frames'])
        if nb_epochs == 1:
            # single-epoch runs are probably broken -> highlight in red
            print(f'{run_name} -> {colored(f"nb_epochs {nb_epochs}", "red")}')
        else:
            print('{} -> nb_epochs {}'.format(run_name, nb_epochs))
        datasets.append(data_dict)
    return datasets
def get_datasets(rootdir, load_only="", load_subsample_step=1, ignore_patterns=("ignore",), require_patterns=()):
    """Discover model directories under *rootdir* and load their data.

    Side effects: fills the module-level ``models_saves``, ``model_eval_data``,
    ``labels`` and ``colors`` registries.

    Args:
        rootdir: storage root containing one sub-directory per model.
        load_only: only directories whose name contains this substring load.
        load_subsample_step: row sub-sampling step for the CSV logs.
        ignore_patterns: directories containing any of these substrings are
            skipped.  BUGFIX: the default used to be ``("ignore")`` — a plain
            *string* (missing comma) — so the loop iterated per character.
        require_patterns: if non-empty, a directory must contain at least one
            of these substrings to be loaded.
    """
    _, models_list, _ = next(os.walk(rootdir))
    for dir_name in models_list.copy():
        # add "ignore" in a directory name to avoid loading its content
        for ignore_pattern in ignore_patterns:
            if ignore_pattern in dir_name or load_only not in dir_name:
                if dir_name in models_list:
                    models_list.remove(dir_name)
        if len(require_patterns) > 0:
            if not any(require_pattern in dir_name for require_pattern in require_patterns):
                if dir_name in models_list:
                    models_list.remove(dir_name)

    # drop stale label entries for models that were filtered out
    for expe_name in list(labels.keys()):
        if expe_name not in models_list:
            del labels[expe_name]

    # setting per-model type colors
    for i, m_name in enumerate(models_list):
        for m_type, m_color in per_model_colors.items():
            if m_type in m_name:
                colors[m_name] = m_color
        print("extracting data for {}...".format(m_name))
        m_id = m_name
        models_saves[m_id] = OrderedDict()
        models_saves[m_id]['data'] = get_all_runs(rootdir + m_name, load_subsample_step=load_subsample_step)
        print("done")
        if m_name not in labels:
            labels[m_name] = m_name
        model_eval_data[m_id] = get_eval_data(logdir=rootdir + m_name, eval_metric=eval_metric)
""" | |
retrieve all experiences located in "data to vizu" folder | |
""" | |
labels = OrderedDict() | |
per_model_colors = OrderedDict() | |
# per_model_colors = OrderedDict([('ALP-GMM',u'#1f77b4'), | |
# ('hmn','pink'), | |
# ('ADR','black')]) | |
# LOAD DATA | |
models_saves = OrderedDict() | |
colors = OrderedDict() | |
model_eval_data = OrderedDict() | |
static_lines = {} | |
# get_datasets("storage/",load_only="RERUN_WizardGuide") | |
# get_datasets("storage/",load_only="RERUN_WizardTwoGuides") | |
try: | |
load_pattern = eval(sys.argv[1]) | |
except: | |
load_pattern = sys.argv[1] | |
ignore_patterns = ["_ignore_"] | |
require_patterns = [ | |
"_" | |
] | |
# require_patterns = [ | |
# "dummy_cs_jz_scaf_A_E_N_A_E", | |
# "03-12_dummy_cs_jz_formats_AE", | |
# ] | |
# | |
# def label_parser(label, figure_id, label_parser_dict=None): | |
# if "single" in label: | |
# ty = "single" | |
# elif "group" in label: | |
# ty = "group" | |
# | |
# if "asoc" in label: | |
# return f"Asocial_pretrain({ty})" | |
# | |
# if "exp_soc" in label: | |
# return f"Role_B_pretrain({ty})" | |
# | |
# return label | |
# | |
# # DUMMY FORMATS | |
# require_patterns = [ | |
# "03-12_dummy_cs_formats_CBL", | |
# "dummy_cs_formats_CBL_N_rec_5" | |
# "03-12_dummy_cs_jz_formats_", | |
# "dummy_cs_jz_formats_N_rec_5" | |
# ] | |
# def label_parser(label, figure_id, label_parser_dict=None): | |
# if "CBL" in label: | |
# eb = "CBL" | |
# else: | |
# eb = "no_bonus" | |
# | |
# if "AE" in label: | |
# label = f"AE_PPO_{eb}" | |
# elif "E" in label: | |
# label = f"E_PPO_{eb}" | |
# elif "A" in label: | |
# label = f"A_PPO_{eb}" | |
# elif "N" in label: | |
# label = f"N_PPO_{eb}" | |
# | |
# return label | |
# | |
# DUMMY CLASSIC | |
# require_patterns = [ | |
# "07-12_dummy_cs_NEW2_Pointing_sm_CB_very_small", | |
# "dummy_cs_JA_Pointing_CB_sm", | |
# "06-12_dummy_cs_NEW_Color_CBL", | |
# "dummy_cs_JA_Color_CBL_new" | |
# "07-12_dummy_cs_NEW2_Feedback_CBL", | |
# "dummy_cs_JA_Feedback_CBL_new" | |
# "08-12_dummy_cs_emulation_no_distr_rec_5_CB_exploration-bonus-type_cell_exploration-bonus-params__1_50", | |
# "08-12_dummy_cs_emulation_no_distr_rec_5_CB", | |
# "dummy_cs_RR_ft_NEW_single_CB_marble_pass_B_exp_soc", | |
# "dummy_cs_RR_ft_NEW_single_CB_marble_pass_B_contr_asoc", | |
# "dummy_cs_RR_ft_NEW_group_CB_marble_pass_A_exp_soc", | |
# "dummy_cs_RR_ft_NEW_group_CB_marble_pass_A_contr_asoc" | |
# "03-12_dummy_cs_jz_formats_A", | |
# "03-12_dummy_cs_jz_formats_E", | |
# "03-12_dummy_cs_jz_formats_AE", | |
# "dummy_cs_jz_formats_N_rec_5" | |
# "03-12_dummy_cs_formats_CBL_A", | |
# "03-12_dummy_cs_formats_CBL_E", | |
# "03-12_dummy_cs_formats_CBL_AE", | |
# "dummy_cs_formats_CBL_N_rec_5" | |
# "03-12_dummy_cs_jz_formats_AE", | |
# "dummy_cs_jz_scaf_A_E_N_A_E_full-AEfull", | |
# "dummy_cs_jz_scaf_A_E_N_A_E_scaf_full-AEfull", | |
# ] | |
# def label_parser(label, figure_id, label_parser_dict=None): | |
# label = label.replace("07-12_dummy_cs_NEW2_Pointing_sm_CB_very_small", "PPO_CB") | |
# label = label.replace("dummy_cs_JA_Pointing_CB_sm", "JA_PPO_CB") | |
# | |
# label = label.replace("06-12_dummy_cs_NEW_Color_CBL", "PPO_CBL") | |
# label = label.replace("dummy_cs_JA_Color_CBL_new", "JA_PPO_CBL") | |
# | |
# label = label.replace("07-12_dummy_cs_NEW2_Feedback_CBL", "PPO_CBL") | |
# label = label.replace("dummy_cs_JA_Feedback_CBL_new", "JA_PPO_CBL") | |
# | |
# label = label.replace( | |
# "08-12_dummy_cs_emulation_no_distr_rec_5_CB_exploration-bonus-type_cell_exploration-bonus-params__1_50", | |
# "PPO_CB_1") | |
# label = label.replace( | |
# "08-12_dummy_cs_emulation_no_distr_rec_5_CB_exploration-bonus-type_cell_exploration-bonus-params__1_50", | |
# "PPO_CB_1") | |
# | |
# label = label.replace("dummy_cs_RR_ft_NEW_single_CB_marble_pass_B_exp_soc", "PPO_CB_role_B_single") | |
# label = label.replace("dummy_cs_RR_ft_NEW_single_CB_marble_pass_B_contr_asoc", "PPO_CB_asoc_single") | |
# | |
# label = label.replace("dummy_cs_RR_ft_NEW_group_CB_marble_pass_A_exp_soc", "PPO_CB_role_B_group") | |
# label = label.replace("dummy_cs_RR_ft_NEW_group_CB_marble_pass_A_contr_asoc", "PPO_CB_asoc_group") | |
# | |
# label = label.replace( | |
# "03-12_dummy_cs_formats_CBL_A_rec_5_env_SocialAI-ALangFeedbackTrainFormatsCSParamEnv-v1_recurrence_5_test-set-name_AFormatsTestSet_exploration-bonus-type_lang", | |
# "PPO_CBL_Ask") | |
# label = label.replace( | |
# "03-12_dummy_cs_formats_CBL_E_rec_5_env_SocialAI-ELangFeedbackTrainFormatsCSParamEnv-v1_recurrence_5_test-set-name_EFormatsTestSet_exploration-bonus-type_lang", | |
# "PPO_CBL_Eye_contact") | |
# label = label.replace( | |
# "03-12_dummy_cs_formats_CBL_AE_rec_5_env_SocialAI-AELangFeedbackTrainFormatsCSParamEnv-v1_recurrence_5_test-set-name_AEFormatsTestSet_exploration-bonus-type_lang", | |
# "PPO_CBL_Ask_Eye_contact") | |
# label = label.replace("dummy_cs_formats_CBL_N_rec_5", "PPO_CBL_No") | |
# | |
# label = label.replace( | |
# "03-12_dummy_cs_jz_formats_E_rec_5_env_SocialAI-ELangFeedbackTrainFormatsCSParamEnv-v1_recurrence_5_test-set-name_EFormatsTestSet", | |
# "PPO_no_bonus_Eye_contact") | |
# label = label.replace( | |
# "03-12_dummy_cs_jz_formats_A_rec_5_env_SocialAI-ALangFeedbackTrainFormatsCSParamEnv-v1_recurrence_5_test-set-name_AFormatsTestSet", | |
# "PPO_no_bonus_Ask") | |
# label = label.replace( | |
# "03-12_dummy_cs_jz_formats_AE_rec_5_env_SocialAI-AELangFeedbackTrainFormatsCSParamEnv-v1_recurrence_5_test-set-name_AEFormatsTestSet", | |
# "PPO_no_bonus_Ask_Eye_contact") | |
# label = label.replace("dummy_cs_jz_formats_N_rec_5", "PPO_no_bonus_No") | |
# | |
# label = label.replace("03-12_dummy_cs_jz_formats_AE", "PPO_no_bonus_no_scaf") | |
# label = label.replace("dummy_cs_jz_scaf_A_E_N_A_E_full-AEfull", "PPO_no_bonus_scaf_4") | |
# label = label.replace("dummy_cs_jz_scaf_A_E_N_A_E_scaf_full-AEfull", "PPO_no_bonus_scaf_8") | |
# | |
# return label | |
# Final case studies.
# A directory must match at least one of these substrings to be loaded;
# "_" matches every run.  The commented entries are the individual case
# studies (pointing, role reversal, scaffolding) used previously.
require_patterns = [
    "_",
    # "04-01_Pointing_CB_heldout_doors",
    # "03-01_RR_ft_single_CB_marble_pass_A_asoc_contr",
    # "03-01_RR_ft_single_CB_marble_pass_A_soc_exp",
    # "05-01_RR_ft_group_50M_CB_marble_pass_A_asoc_contr",
    # "05-01_RR_ft_group_50M_CB_marble_pass_A_soc_exp",
    # "05-01_scaffolding_50M_no",
    # "05-01_scaffolding_50M_acl_4_acl-type_intro_seq",
    # "05-01_scaffolding_50M_acl_8_acl-type_intro_seq_scaf",
]
def label_parser(label, figure_id, label_parser_dict=None):
    """Map raw experiment-directory names to human-readable plot labels.

    Applies a fixed sequence of substring renames (order preserved from the
    original replace chain); unknown labels pass through unchanged.
    """
    renames = (
        ("04-01_Pointing_CB_heldout_doors", "PPO_CB"),
        ("05-01_scaffolding_50M_no_acl", "PPO_no_scaf"),
        ("05-01_scaffolding_50M_acl_4_acl-type_intro_seq", "PPO_scaf_4"),
        ("05-01_scaffolding_50M_acl_8_acl-type_intro_seq_scaf", "PPO_scaf_8"),
        ("03-01_RR_ft_single_CB_marble_pass_A_soc_exp", "PPO_CB_role_B"),
        ("03-01_RR_ft_single_CB_marble_pass_A_asoc_contr", "PPO_CB_asocial"),
        ("05-01_RR_ft_group_50M_CB_marble_pass_A_soc_exp", "PPO_CB_role_B"),
        ("05-01_RR_ft_group_50M_CB_marble_pass_A_asoc_contr", "PPO_CB_asocial"),
    )
    for raw, pretty in renames:
        label = label.replace(raw, pretty)
    return label
# Optional fixed label -> color mapping; an empty dict means colors are
# assigned automatically.  The commented groups below are the mappings used
# for earlier case studies (JA, RR-group, formats with/without the CBL bonus).
color_dict = {
    # "JA_PPO_CBL": "blue",
    # "PPO_CBL": "orange",
    # "PPO_CB_role_B_group": "orange",
    # "PPO_CB_asoc_group": "blue",
    # "PPO_no_bonus_No": "blue",
    # "PPO_no_bonus_Eye_contact": "magenta",
    # "PPO_no_bonus_Ask": "orange",
    # "PPO_no_bonus_Ask_Eye_contact": "green",
    # "PPO_CBL_No": "blue",
    # "PPO_CBL_Eye_contact": "magenta",
    # "PPO_CBL_Ask": "orange",
    # "PPO_CBL_Ask_Eye_contact": "green",
}
# # POINTING_GENERALIZATION (DUMMY) | |
# require_patterns = [ | |
# "29-10_SAI_Pointing_CS_PPO_CB_", | |
# "29-10_SAI_LangColor_CS_PPO_CB_" | |
# ] | |
# | |
# color_dict = { | |
# "dummy_cs_JA_Feedback_CBL_new": "blue", | |
# "dummy_cs_Feedback_CBL": "orange", | |
# } | |
# | |
# def label_parser(label, figure_id, label_parser_dict=None): | |
# label = label.split("_env_")[0].split("SAI_")[1] | |
# label=label.replace("Pointing_CS_PPO_CB", "PPO_CB_train(DUMMY)") | |
# label=label.replace("LangColor_CS_PPO_CB", "PPO_CB_test(DUMMY)") | |
# return label | |
# | |
# eval_filename = f"/home/flowers/Documents/projects/embodied_acting_and_speaking/case_studies_figures/Pointing_gen_eval.png" | |
# # FEEDBACK GENERALIZATION (DUMMY) | |
# require_patterns = [ | |
# "29-10_SAI_LangFeedback_CS_PPO_CBL_", | |
# "29-10_SAI_LangColor_CS_PPO_CB_" | |
# ] | |
# | |
# color_dict = { | |
# "PPO_CBL_train(DUMMY)": "blue", | |
# "PPO_CBL_test(DUMMY)": "maroon", | |
# } | |
# | |
# def label_parser(label, figure_id, label_parser_dict=None): | |
# label = label.split("_env_")[0].split("SAI_")[1] | |
# label=label.replace("LangFeedback_CS_PPO_CBL", "PPO_CBL_train(DUMMY)") | |
# label=label.replace("LangColor_CS_PPO_CB", "PPO_CBL_test(DUMMY)") | |
# return label | |
# | |
# eval_filename = f"/home/flowers/Documents/projects/embodied_acting_and_speaking/case_studies_figures/Feedback_gen_eval.png" | |
# # COLOR GENERALIZATION (DUMMY) | |
# require_patterns = [ | |
# "29-10_SAI_LangColor_CS_PPO_CBL_", | |
# "29-10_SAI_LangColor_CS_PPO_CB_" | |
# ] | |
# | |
# color_dict = { | |
# "PPO_CBL_train(DUMMY)": "blue", | |
# "PPO_CBL_test(DUMMY)": "maroon", | |
# } | |
# | |
# def label_parser(label, figure_id, label_parser_dict=None): | |
# label = label.split("_env_")[0].split("SAI_")[1] | |
# label=label.replace("LangColor_CS_PPO_CBL", "PPO_CBL_train(DUMMY)") | |
# label=label.replace("LangColor_CS_PPO_CB", "PPO_CBL_test(DUMMY)") | |
# return label | |
# | |
# eval_filename = f"/home/flowers/Documents/projects/embodied_acting_and_speaking/case_studies_figures/Color_gen_eval.png" | |
# # POINTING - PILOT | |
# require_patterns = [ | |
# "29-10_SAI_Pointing_CS_PPO_", | |
# ] | |
# | |
# color_dict = { | |
# "PPO_RIDE": "orange", | |
# "PPO_RND": "magenta", | |
# "PPO_no": "maroon", | |
# "PPO_CBL": "green", | |
# "PPO_CB": "blue", | |
# } | |
# | |
# def label_parser(label, figure_id, label_parser_dict=None): | |
# label = label.split("_env_")[0].split("SAI_")[1] | |
# label=label.replace("Pointing_CS_", "") | |
# return label | |
# # | |
# eval_filename = f"/home/flowers/Documents/projects/embodied_acting_and_speaking/case_studies_figures/Pointing_eval.png" | |
# LANGCOLOR - 7 Colors - PILOT | |
# require_patterns = [ | |
# "29-10_SAI_LangColor_CS_PPO_", | |
# ] | |
# | |
# color_dict = { | |
# "PPO_RIDE": "orange", | |
# "PPO_RND": "magenta", | |
# "PPO_no": "maroon", | |
# "PPO_CBL": "green", | |
# "PPO_CB": "blue", | |
# } | |
# | |
# def label_parser(label, figure_id, label_parser_dict=None): | |
# label = label.split("_env_")[0].split("SAI_")[1] | |
# label=label.replace("LangColor_CS_", "") | |
# return label | |
# | |
# eval_filename = f"/home/flowers/Documents/projects/embodied_acting_and_speaking/case_studies_figures/Color_eval.png" | |
# # LangColor - CBL - 3 5 7 | |
# require_patterns = [ | |
# "02-11_SAI_LangColor_CS_5C_PPO_CBL", | |
# "02-11_SAI_LangColor_CS_3C_PPO_CBL", | |
# "29-10_SAI_LangColor_CS_PPO_CBL" | |
# ] | |
# RND RIDE reference : RIDE > RND > no | |
# require_patterns = [ | |
# "24-08_new_ref", | |
# ] | |
# # # LANG FEEDBACK | |
# require_patterns = [ | |
# "24-10_SAI_LangFeedback_CS_PPO_", | |
# "29-10_SAI_LangFeedback_CS_PPO_", | |
# ] | |
# color_dict = { | |
# "PPO_RIDE": "orange", | |
# "PPO_RND": "magenta", | |
# "PPO_no": "maroon", | |
# "PPO_CBL": "green", | |
# "PPO_CB": "blue", | |
# } | |
# | |
# def label_parser(label, figure_id, label_parser_dict=None): | |
# label = label.split("_env_")[0].split("SAI_")[1] | |
# label=label.replace("LangFeedback_CS_", "") | |
# return label | |
# | |
# # eval_filename = f"/home/flowers/Documents/projects/embodied_acting_and_speaking/case_studies_figures/Feedback_eval.png" | |
# | |
# # ROLE REVERSAL - group (DUMMY) | |
# require_patterns = [ | |
# "24-10_SAI_LangFeedback_CS_PPO_CB_", | |
# "29-10_SAI_LangFeedback_CS_PPO_CBL_", | |
# ] | |
# color_dict = { | |
# "PPO_CB_experimental": "green", | |
# "PPO_CB_control": "blue", | |
# } | |
# color_dict=None | |
# | |
# def label_parser(label, figure_id, label_parser_dict=None): | |
# label = label.split("_env_")[0].split("SAI_")[1] | |
# label=label.replace("LangFeedback_CS_", "") | |
# | |
# label=label.replace("PPO_CB", "PPO_CB_control") | |
# label=label.replace("controlL", "experimental") | |
# | |
# return label | |
# | |
# eval_filename = f"/home/flowers/Documents/projects/embodied_acting_and_speaking/case_studies_figures/RR_dummy_group.png" | |
# # ROLE REVERSAL - single (DUMMY) | |
# require_patterns = [ | |
# "24-10_SAI_LangFeedback_CS_PPO_CB_", | |
# "24-10_SAI_LangFeedback_CS_PPO_no_", | |
# ] | |
# color_dict = { | |
# "PPO_CB_experimental": "green", | |
# "PPO_CB_control": "blue", | |
# } | |
# color_dict=None | |
# | |
# def label_parser(label, figure_id, label_parser_dict=None): | |
# label = label.split("_env_")[0].split("SAI_")[1] | |
# label=label.replace("LangFeedback_CS_", "") | |
# | |
# label=label.replace("PPO_CB", "PPO_CB_control") | |
# label=label.replace("PPO_no", "PPO_CB_experimental") | |
# | |
# return label | |
# | |
# eval_filename = f"/home/flowers/Documents/projects/embodied_acting_and_speaking/case_studies_figures/RR_dummy_single.png" | |
# # IMITATION train (DUMMY) | |
# require_patterns = [ | |
# "29-10_SAI_LangFeedback_CS_PPO_CBL_", | |
# "29-10_SAI_Pointing_CS_PPO_RIDE", | |
# ] | |
# | |
# color_dict = { | |
# "PPO_CB_no_distr(DUMMY)": "magenta", | |
# "PPO_CB_distr(DUMMY)": "orange", | |
# } | |
# | |
# def label_parser(label, figure_id, label_parser_dict=None): | |
# label = label.split("_env_")[0].split("SAI_")[1] | |
# label=label.replace("LangFeedback_CS_PPO_CBL", "PPO_CB_no_distr(DUMMY)") | |
# label=label.replace("Pointing_CS_PPO_RIDE", "PPO_CB_distr(DUMMY)") | |
# return label | |
# | |
# eval_filename = f"/home/flowers/Documents/projects/embodied_acting_and_speaking/case_studies_figures/Imitation_train.png" | |
# # IMITATION test (DUMMY) | |
# require_patterns = [ | |
# "29-10_SAI_LangFeedback_CS_PPO_CBL_", | |
# "29-10_SAI_Pointing_CS_PPO_RIDE", | |
# ] | |
# | |
# color_dict = { | |
# "PPO_CB_no_distr(DUMMY)": "magenta", | |
# "PPO_CB_distr(DUMMY)": "orange", | |
# } | |
# | |
# def label_parser(label, figure_id, label_parser_dict=None): | |
# label = label.split("_env_")[0].split("SAI_")[1] | |
# label=label.replace("LangFeedback_CS_PPO_CBL", "PPO_CB_no_distr(DUMMY)") | |
# label=label.replace("Pointing_CS_PPO_RIDE", "PPO_CB_distr(DUMMY)") | |
# return label | |
# | |
# eval_filename = f"/home/flowers/Documents/projects/embodied_acting_and_speaking/case_studies_figures/Imitation_test.png" | |
# JA_POINTING | |
# require_patterns = [ | |
# "29-10_SAI_Pointing_CS_PPO_CB_", | |
# "04-11_SAI_JA_Pointing_CS_PPO_CB_less", # less reward | |
# ] | |
# color_dict = { | |
# "JA_Pointing_PPO_CB": "orange", | |
# "Pointing_PPO_CB": "blue", | |
# } | |
# | |
# def label_parser(label, figure_id, label_parser_dict=None): | |
# label = label.split("_env_")[0].split("SAI_")[1] | |
# label=label.replace("_CS_", "_") | |
# label=label.replace("_less_", "") | |
# return label | |
# | |
# eval_filename = f"/home/flowers/Documents/projects/embodied_acting_and_speaking/case_studies_figures/JA_Pointing_eval.png" | |
# # JA_COLORS (JA, no) x (3,5,7) | |
# max_x_lim = 17 | |
# require_patterns = [ | |
# # "02-11_SAI_JA_LangColor", # max_x_lim = 17 | |
# "02-11_SAI_JA_LangColor_CS_3C", # max_x_lim = 17 | |
# # "02-11_SAI_LangColor_CS_5C_PPO_CBL", # max_x_lim = 17 | |
# "02-11_SAI_LangColor_CS_3C_PPO_CBL", | |
# # "29-10_SAI_LangColor_CS_PPO_CBL" | |
# ] | |
# color_dict = { | |
# "JA_LangColor_PPO_CBL": "orange", | |
# "LangColor_PPO_CBL": "blue", | |
# } | |
# def label_parser(label, figure_id, label_parser_dict=None): | |
# label = label.split("_env_")[0].split("SAI_")[1] | |
# label=label.replace("_CS_", "_") | |
# label=label.replace("_3C_", "_") | |
# return label | |
# eval_filename = f"/home/flowers/Documents/projects/embodied_acting_and_speaking/case_studies_figures/JA_Color_eval.png" | |
# JA_FEEDBACK -> max_xlim=17 | |
# max_x_lim = 17 | |
# require_patterns = [ | |
# "02-11_SAI_JA_LangFeedback_CS_PPO_CBL_", | |
# "29-10_SAI_LangFeedback_CS_PPO_CBL_", | |
# "dummy_cs_F", | |
# "dummy_cs_JA_F" | |
# ] | |
# color_dict = { | |
# "JA_LangFeedback_PPO_CBL": "orange", | |
# "LangFeedback_PPO_CBL": "blue", | |
# } | |
# | |
# def label_parser(label, figure_id, label_parser_dict=None): | |
# label = label.split("_env_")[0].split("SAI_")[1] | |
# label=label.replace("_CS_", "_") | |
# return label | |
# | |
# eval_filename = f"/home/flowers/Documents/projects/embodied_acting_and_speaking/case_studies_figures/JA_Feedback_eval.png" | |
# # Formats CBL | |
# require_patterns = [ | |
# "03-11_SAI_LangFeedback_CS_F_NO_PPO_CBL_env_SocialAI", | |
# "29-10_SAI_LangFeedback_CS_PPO_CBL_env_SocialAI", | |
# "03-11_SAI_LangFeedback_CS_F_ASK_PPO_CBL_env_SocialAI", | |
# "03-11_SAI_LangFeedback_CS_F_ASK_EYE_PPO_CBL_env_SocialAI", | |
# ] | |
# color_dict = { | |
# "LangFeedback_Eye_PPO_CBL": "blue", | |
# "LangFeedback_Ask_PPO_CBL": "orange", | |
# "LangFeedback_NO_PPO_CBL": "green", | |
# "LangFeedback_AskEye_PPO_CBL": "magenta", | |
# } | |
# | |
# def label_parser(label, figure_id, label_parser_dict=None): | |
# label = label.split("_env_")[0].split("SAI_")[1] | |
# label=label.replace("_CS_", "_") | |
# label=label.replace("_F_", "_") | |
# | |
# label=label.replace("LangFeedback_PPO", "LangFeedback_EYE_PPO") | |
# | |
# label=label.replace("EYE", "Eye") | |
# label=label.replace("No", "No") | |
# label=label.replace("ASK", "Ask") | |
# label=label.replace("Ask_Eye", "AskEye") | |
# return label | |
# | |
# eval_filename = f"/home/flowers/Documents/projects/embodied_acting_and_speaking/case_studies_figures/Formats_CBL_eval.png" | |
# # Formats NO | |
# require_patterns = [ | |
# "24-10_SAI_LangFeedback_CS_PPO_no", # EYE | |
# "04-11_SAI_LangFeedback_CS_F_NO_PPO_NO_env_SocialAI", | |
# "04-11_SAI_LangFeedback_CS_F_ASK_PPO_NO_env_SocialAI", | |
# "04-11_SAI_LangFeedback_CS_F_ASK_EYE_PPO_NO_env_SocialAI", | |
# ] | |
# | |
# color_dict = { | |
# "LangFeedback_Eye_PPO_no": "blue", | |
# "LangFeedback_Ask_PPO_no": "orange", | |
# "LangFeedback_NO_PPO_no": "green", | |
# "LangFeedback_AskEye_PPO_no": "magenta", | |
# } | |
# | |
# def label_parser(label, figure_id, label_parser_dict=None): | |
# label = label.split("_env_")[0].split("SAI_")[1] | |
# label=label.replace("_CS_", "_") | |
# label=label.replace("_F_", "_") | |
# # | |
# label=label.replace("LangFeedback_PPO", "LangFeedback_EYE_PPO") | |
# label=label.replace("PPO_NO", "PPO_no") | |
# | |
# label=label.replace("EYE", "Eye") | |
# label=label.replace("No", "No") | |
# label=label.replace("ASK", "Ask") | |
# label=label.replace("Ask_Eye", "AskEye") | |
# return label | |
# | |
# eval_filename = f"/home/flowers/Documents/projects/embodied_acting_and_speaking/case_studies_figures/Formats_no_eval.png" | |
# | |
# require_patterns = [ | |
# "11-07_bAI_cb_GS_param_tanh_env_SocialAI-SocialAIParamEnv-v1_exploration-bonus-type_cell_exploration-bonus-params__2_50_exploration-bonus-tanh_0.6", | |
# # "04-11_SAI_ImitationDistr_CS_PPO_CB_small_env_SocialAI-EEmulationDistrInformationSeekingParamEnv-v1_recurrence_10", | |
# # "04-11_SAI_ImitationDistr_CS_PPO_CB_small_env_SocialAI-EEmulationDistrInformationSeekingParamEnv-v1_recurrence_10", | |
# "03-11_SAI_ImitationDistr_CS_PPO_CB_env_SocialAI-EEmulationDistrInformationSeekingParamEnv-v1_recurrence_10", | |
# # "04-11_SAI_ImitationNoDistr_CS_PPO_CB_small_env_SocialAI-EEmulationNoDistrInformationSeekingParamEnv-v1_recurrence_10", | |
# ] | |
# require_patterns = [ | |
# "02-11_SAI_LangColor_CS_3C_PPO_CBL", | |
# "02-11_SAI_JA_LangColor_CS_3C_PPO_CBL", | |
# ] # at least one of those | |
# all of those | |
# Patterns that every loaded directory name must contain ("_" matches all).
include_patterns = [
    "_"
]
# include_patterns = ["rec_5"]
# Font/linewidth presets: large when saving figures (eval_filename set),
# small for interactive viewing.  NOTE: the unconditional assignments after
# the branch override it, so the small preset is always in effect as written.
if eval_filename:
    # saving
    fontsize = 40
    legend_fontsize = 30
    linewidth = 10
else:
    fontsize = 5
    legend_fontsize = 5
    linewidth = 1

fontsize = 5
legend_fontsize = 5
linewidth = 1

title_fontsize = int(fontsize * 1.2)
storage_dir = "storage/" | |
if load_pattern.startswith(storage_dir): | |
load_pattern = load_pattern[len(storage_dir):] | |
if load_pattern.startswith("./storage/"): | |
load_pattern = load_pattern[len("./storage/"):] | |
get_datasets(storage_dir, str(load_pattern), load_subsample_step=load_subsample_step, ignore_patterns=ignore_patterns, require_patterns=require_patterns) | |
# Optional mapping from raw run names to labels (consumed by label parsers).
label_parser_dict = {
    # "PPO_CB": "PPO_CB",
    # "02-06_AppleStealing_experiments_cb_bonus_angle_occ_env_SocialAI-OthersPerceptionInferenceParamEnv-v1_exploration-bonus-type_cell": "NPC_visible",
}

env_type = str(load_pattern)
fig_type = "test"

# Second CLI argument: how many top runs to keep (defaults to 8).
# BUGFIX: the bare `except:` is narrowed to the two errors int() / indexing
# can actually raise here.
try:
    top_n = int(sys.argv[2])
except (IndexError, ValueError):
    top_n = 8

# Manually exclude loaded runs by name here.
to_remove = []
for tr_ in to_remove:
    if tr_ in models_saves:
        del models_saves[tr_]

print("Loaded:")
print("\n".join(list(models_saves.keys())))
#### get_datasets("storage/", "RERUN_WizardGuide_lang64_nameless") | |
#### get_datasets("storage/", "RERUN_WizardTwoGuides_lang64_nameless") | |
if per_model_colors:
    # Order runs for the legend as in per_model_colors, keeping their colors:
    # take labels whose key contains each teacher type, in palette order.
    labels = OrderedDict(
        (k, v)
        for teacher_type in per_model_colors
        for k, v in labels.items()
        if teacher_type in k
    )
else:
    print('not using per_model_color')
    labels.update((k, k) for k in models_saves)
def plot_with_shade_seed(subplot_nb, ax, x, y, err, color, shade_color, label,
                         y_min=None, y_max=None, legend=False, leg_size=30, leg_loc='best', title=None,
                         ylim=(0, 100), xlim=(0, 40), leg_args=None, leg_linewidth=13.0, linewidth=10.0, labelsize=20,
                         filename=None,
                         zorder=None, xlabel='perf', ylabel='Env steps'):
    """Plot one seed's curve on *ax*, optionally with a shaded error band.

    Args:
        subplot_nb: subplot index; the y-label is only drawn on the first (0).
        ax: matplotlib Axes to draw on.
        x, y: data arrays; x is truncated to len(y).
        err: symmetric half-width of the shaded band, or None for no band.
        leg_args: extra kwargs for ax.legend().  BUGFIX: was a mutable
            default ({}); now a None sentinel.
        ylim, xlim: (min, max) axis limits (tuples; lists also accepted).

    Uses the module-level ``fontsize`` for axis labels and the title.
    """
    leg_args = {} if leg_args is None else leg_args
    plt.rcParams.update({'font.size': 15})
    plt.rcParams['axes.xmargin'] = 0
    plt.rcParams['axes.ymargin'] = 0
    ax.locator_params(axis='x', nbins=3)
    ax.locator_params(axis='y', nbins=3)
    ax.tick_params(axis='both', which='major', labelsize=labelsize)
    x = x[:len(y)]
    # ax.scatter(x, y, color=color, linewidth=linewidth, zorder=zorder)
    ax.plot(x, y, color=color, label=label, linewidth=linewidth, zorder=zorder)
    if err is not None:
        ax.fill_between(x, y - err, y + err, color=shade_color, alpha=0.2)
    if legend:
        leg = ax.legend(loc=leg_loc, **leg_args)
        for legobj in leg.legendHandles:
            legobj.set_linewidth(leg_linewidth)
    ax.set_xlabel(xlabel, fontsize=fontsize)  # NOTE: module-level fontsize
    if subplot_nb == 0:
        ax.set_ylabel(ylabel, fontsize=fontsize, labelpad=4)
    ax.set_xlim(xmin=xlim[0], xmax=xlim[1])
    ax.set_ylim(bottom=ylim[0], top=ylim[1])
    if title:
        ax.set_title(title, fontsize=fontsize)
# Plot utils
def plot_with_shade_grg(subplot_nb, ax, x, y, err, color, shade_color, label,
                        legend=False, leg_loc='best', title=None,
                        ylim=(0, 100), xlim=(0, 40), leg_args=None, leg_linewidth=13.0, linewidth=10.0, labelsize=20, fontsize=20, title_fontsize=30,
                        zorder=None, xlabel='Perf', ylabel='Env steps', linestyle="-", xnbins=3, ynbins=3, filename=None):
    """Plot a mean curve with a shaded error band on *ax*.

    Args:
        subplot_nb: subplot index; the y-label is only drawn on the first (0).
        ax: matplotlib Axes to draw on.
        x, y: data arrays.
        err: symmetric half-width of the shaded band, or None to skip it
            (robustness: matches plot_with_shade_seed's None handling).
        leg_args: extra kwargs for ax.legend().  BUGFIX: was a mutable
            default ({}); now a None sentinel.
        xnbins, ynbins: tick counts per axis.
    """
    leg_args = {} if leg_args is None else leg_args
    ax.locator_params(axis='x', nbins=xnbins)
    ax.locator_params(axis='y', nbins=ynbins)
    ax.tick_params(axis='y', which='both', labelsize=labelsize)
    # x tick labels drawn slightly smaller than y
    ax.tick_params(axis='x', which='both', labelsize=labelsize * 0.8)
    ax.plot(x, y, color=color, label=label, linewidth=linewidth, zorder=zorder, linestyle=linestyle)
    if err is not None:
        ax.fill_between(x, y - err, y + err, color=shade_color, alpha=0.2)
    if legend:
        leg = ax.legend(loc=leg_loc, **leg_args)
        for legobj in leg.legendHandles:
            legobj.set_linewidth(leg_linewidth)
    ax.set_xlabel(xlabel, fontsize=fontsize)
    if subplot_nb == 0:
        ax.set_ylabel(ylabel, fontsize=fontsize, labelpad=2)
    ax.set_xlim(xmin=xlim[0], xmax=xlim[1])
    ax.set_ylim(bottom=ylim[0], top=ylim[1])
    if title:
        ax.set_title(title, fontsize=title_fontsize)
# Metric plot
# Training metric(s) to plot; alternatives kept for quick switching:
# metric = 'success_rate_mean'
# metric = 'mission_string_observed_mean'
# metric = 'extrinsic_return_mean'
# metric = 'extrinsic_return_max'
# metric = "rreturn_mean"
# metric = 'rreturn_max'
# metric = 'FPS'
# metric = 'duration'
# metric = 'intrinsic_reward_perf2_'
# metric = 'NPC_intro'
metrics = [
    'success_rate_mean',
    # 'FPS',
    # 'extrinsic_return_mean',
    # 'exploration_bonus_mean',
    'NPC_intro',
    # 'curriculum_param_mean',
    # 'curriculum_max_success_rate_mean',
    # 'rreturn_mean'
]
# One subplot per selected metric, laid out in a single row.
# f, ax = plt.subplots(1, len(metrics), figsize=(15.0, 9.0))
f, ax = plt.subplots(1, len(metrics), figsize=(9.0, 9.0))
# f, ax = plt.subplots(1, len(metrics), figsize=(20.0, 20.0))
# f, ax = plt.subplots(1, 1, figsize=(5.0, 3.0))
if len(metrics) == 1:
    # keep `ax` indexable even when there is a single subplot
    ax = [ax]
# Running y-bounds across all curves (immediately overridden below).
max_y = -np.inf
min_y = np.inf
# hardcoded
min_y, max_y = 0.0, 1.0
max_steps = 0
exclude_patterns = []
# Old per-figure label parser kept for reference (disabled):
# def label_parser(label, figure_id, label_parser_dict=None):
#
# label = label.split("_env_")[0].split("SAI_")[1]
#
# # # Pointing
# # label=label.replace("Pointing_CS_", "")
#
# # Feedback
# label=label.replace("LangFeedback_CS_", "")
#
#
# # label=label.replace("CS_PPO", "7COL_PPO")
# # label=label.replace("CS_3C_PPO", "3COL_PPO")
# # label=label.replace("CS_5C_PPO", "5COL_PPO")
#
# # label=label.replace("CS_PPO", "Eye_contact_PPO")
# # label=label.replace("CS_F_ASK_PPO", "Ask_PPO")
# # label=label.replace("CS_F_NO_PPO", "NO_PPO")
# # label=label.replace("CS_F_ASK_EYE_PPO", "Ask_Eye_contact_PPO")
# #
# # label=label.replace("PPO_no", "PPO_no_bonus")
# # label=label.replace("PPO_NO", "PPO_no_bonus")
#
# if label_parser_dict:
# if sum([1 for k, v in label_parser_dict.items() if k in label]) != 1:
# if label in label_parser_dict:
# # see if there is an exact match
# return label_parser_dict[label]
# else:
# print("ERROR multiple curves match a lable and there is no exact match for {}".format(label))
# exit()
#
# for k, v in label_parser_dict.items():
# if k in label: return v
#
# else:
# # return label.split("_env_")[1]
# if figure_id not in [1, 2, 3, 4]:
# return label
# else:
# # default
# pass
#
# return label
# Training curves: one subplot per metric, one curve per experiment
# (or one curve per seed when `per_seed` is set).
for metric_i, metric in enumerate(metrics):
    min_y, max_y = 0.0, 1.0
    default_colors = default_colors_.copy()
    for model_i, m_id in enumerate(models_saves.keys()):
        # excluding some experiments
        if any([ex_pat in m_id for ex_pat in exclude_patterns]):
            continue
        if len(include_patterns) > 0:
            if not any([in_pat in m_id for in_pat in include_patterns]):
                continue
        runs_data = models_saves[m_id]['data']
        ys = []
        # Logs can contain a repeated CSV header: drop literal 'frames' entries.
        if runs_data[0]['frames'][1] == 'frames':
            runs_data[0]['frames'] = list(filter(('frames').__ne__, runs_data[0]['frames']))
        ###########################################
        if per_seed:
            min_len = None
        else:
            # Keep only the `top_n` longest runs, then align them all to the
            # shortest length among the kept runs.
            lens = [len(run['frames']) for run in runs_data if len(run['frames'])]
            minimum = sorted(lens)[-min(top_n, len(lens))]
            min_len = np.min([len(run['frames']) for run in runs_data if len(run['frames']) >= minimum])
            # keep only top k
            runs_data = [run for run in runs_data if len(run['frames']) >= minimum]
        # compute env steps (x axis) from the longest remaining run
        longest_id = np.argmax([len(rd['frames']) for rd in runs_data])
        # np.int was removed in NumPy 1.24 -> use the builtin int dtype.
        steps = np.array(runs_data[longest_id]['frames'], dtype=int) / steps_denom
        steps = steps[:min_len]
        for run in runs_data:
            if metric not in run:
                # success_rate_mean <==> bin_extrinsic_return_mean
                if metric == 'success_rate_mean':
                    metric_ = "bin_extrinsic_return_mean"
                    if metric_ not in run:
                        # (fixed: message previously had 4 placeholders for 3 args)
                        raise ValueError("Neither {} nor {} is present. Possible metrics: {}.".format(
                            metric, metric_, list(run.keys())))
                    data = run[metric_]
                else:
                    raise ValueError("Unknown metric: {} Possible metrics: {}. ".format(metric, list(run.keys())))
            else:
                data = run[metric]
            # Drop a repeated header row in the metric column as well.
            if data[1] == metric:
                data = np.array(list(filter((metric).__ne__, data)), dtype=np.float16)
            ###########################################
            if per_seed:
                ys.append(data)
            else:
                if len(data) >= min_len:
                    if len(data) > min_len:
                        print("run has too many {} datapoints ({}). Discarding {}".format(
                            m_id, len(data), len(data) - min_len))
                        data = data[0:min_len]
                    ys.append(data)
                else:
                    raise ValueError("How can data be < min_len if it was capped above")
        ys_same_len = ys
        # computes stats
        n_seeds = len(ys_same_len)
        if per_seed:
            # keep the raw per-seed curves, no aggregation
            sems = np.array(ys_same_len)
            stds = np.array(ys_same_len)
            means = np.array(ys_same_len)
            color = default_colors[model_i]
        else:
            sems = np.std(ys_same_len, axis=0) / np.sqrt(len(ys_same_len))  # sem
            stds = np.std(ys_same_len, axis=0)  # std
            means = np.mean(ys_same_len, axis=0)
            color = default_colors[model_i]
        # per-metric adjustments: human-readable axis label
        ylabel = metric
        ylabel = {
            "success_rate_mean": "Success rate",
            "exploration_bonus_mean": "Exploration bonus",
            "NPC_intro": "Successful introduction (%)",
        }.get(ylabel, ylabel)
        if metric == 'duration':
            ylabel = "time (hours)"
            means = means / 3600
            sems = sems / 3600
            stds = stds / 3600
        if per_seed:
            # plot x y bounds
            curr_max_y = np.max(np.max(means))
            curr_min_y = np.min(np.min(means))
            curr_max_steps = np.max(np.max(steps))
        else:
            # plot x y bounds
            curr_max_y = np.max(means + stds)
            curr_min_y = np.min(means - stds)
            curr_max_steps = np.max(steps)
        if curr_max_y > max_y:
            max_y = curr_max_y
        if curr_min_y < min_y:
            min_y = curr_min_y
        if curr_max_steps > max_steps:
            max_steps = curr_max_steps
        if subsample_step:
            # thin out the curves for lighter figures
            steps = steps[0::subsample_step]
            means = means[0::subsample_step]
            stds = stds[0::subsample_step]
            sems = sems[0::subsample_step]
            ys_same_len = [y[0::subsample_step] for y in ys_same_len]
        # display seeds separately
        if per_seed:
            for s_i, seed_ys in enumerate(ys_same_len):
                seed_c = default_colors[model_i + s_i]
                label = str(s_i)
                seed_ys = smooth(seed_ys, smooth_factor)
                plot_with_shade_seed(0, ax[metric_i], steps, seed_ys, None, seed_c, seed_c, label,
                                     legend=draw_legend, xlim=[0, max_steps], ylim=[min_y, max_y],
                                     leg_size=leg_size, xlabel=f"Env steps (1e6)", ylabel=ylabel,
                                     linewidth=linewidth, labelsize=fontsize)
                # record the final value per seed for the textual summary
                summary_dict[s_i] = seed_ys[-1]
                summary_dict_colors[s_i] = seed_c
        else:
            label = label_parser(m_id, load_pattern, label_parser_dict=label_parser_dict)
            if color_dict:
                color = color_dict[label]
            else:
                color = default_colors[model_i]
            label = label + "({})".format(n_seeds)  # append seed count
            if smooth_factor:
                means = smooth(means, smooth_factor)
                stds = smooth(stds, smooth_factor)
            x_lim = max(steps[-1], x_lim)
            x_lim = min(max_x_lim, x_lim)
            leg_args = {
                'fontsize': legend_fontsize
            }
            plot_with_shade_grg(
                0, ax[metric_i], steps, means, stds, color, color, label,
                legend=draw_legend and metric_i == 0,  # legend on first subplot only
                xlim=[0, x_lim],
                ylim=[0, max_y],
                xlabel=f"Env steps (1e6)",
                ylabel=ylabel,
                title=None,
                labelsize=fontsize * train_inc_font,
                fontsize=fontsize * train_inc_font,
                title_fontsize=title_fontsize,
                linewidth=linewidth,
                leg_linewidth=5,
                leg_args=leg_args,
                xnbins=xnbins,
                ynbins=ynbins,
            )
            # record the final mean value for the textual summary
            summary_dict[label] = means[-1]
            summary_dict_colors[label] = color
    if len(summary_dict) == 0:
        raise ValueError(f"No experiments found for {load_pattern}.")
    # print summary: colorize best / top-30% / bottom-30% final values
    best = max(summary_dict.values())
    pc = 0.3
    n = int(len(summary_dict) * pc)
    print("top n: ", n)
    top_pc = sorted(summary_dict.values())[-n:]
    bottom_pc = sorted(summary_dict.values())[:n]
    print("legend:")
    cprint("\tbest", "green")
    cprint("\ttop {} %".format(pc), "blue")
    cprint("\tbottom {} %".format(pc), "red")
    print("\tothers")
    print()
    for l, p in sorted(summary_dict.items(), key=lambda kv: kv[1]):
        c = summary_dict_colors[l]
        if p == best:
            cprint("label: {} ({})".format(l, c), "green")
            cprint("\t {}:{}".format(metric, p), "green")
        elif p in top_pc:
            cprint("label: {} ({})".format(l, c), "blue")
            cprint("\t {}:{}".format(metric, p), "blue")
        elif p in bottom_pc:
            cprint("label: {} ({})".format(l, c), "red")
            cprint("\t {}:{}".format(metric, p), "red")
        else:
            print("label: {} ({})".format(l, c))
            print("\t {}:{}".format(metric, p))
    # Dashed horizontal reference lines (e.g. baselines).
    # NOTE(review): `steps`, `ylabel` and `leg_args` come from the last model
    # iteration above — confirm at least one non-per_seed model was plotted
    # before relying on static_lines.
    for label, (mean, std, color) in static_lines.items():
        plot_with_shade_grg(
            0, ax[metric_i], steps, np.array([mean] * len(steps)), np.array([std] * len(steps)), color, color, label,
            legend=True,
            xlim=[0, x_lim],
            ylim=[0, 1.0],
            xlabel=f"Env steps (1e6)",
            ylabel=ylabel,
            linestyle=":",
            leg_args=leg_args,
            fontsize=fontsize,
            title_fontsize=title_fontsize,
            xnbins=xnbins,
            ynbins=ynbins,
        )
# plt.tight_layout()
# f.savefig('graphics/{}_{}_results.svg'.format(str(figure_id, metric)))
# f.savefig('graphics/{}_{}_results.png'.format(str(figure_id, metric)))
cprint("Ignore pattern: {}".format(ignore_patterns), "blue")
# Show the training figure (blocking) if requested.
if plot_train:
    plt.tight_layout()
    # plt.subplots_adjust(hspace=1.5, wspace=0.5, left=0.1, right=0.9, bottom=0.1, top=0.85)
    plt.subplots_adjust(hspace=1.5, wspace=0.5, left=0.1, right=0.9, bottom=0.1, top=0.85)
    plt.suptitle(super_title)
    plt.show()
# Close the training figure before building the evaluation one.
# NOTE(review): assumed top-level (runs even when plot_train is False) — confirm.
plt.close()
# Reset plotting bounds/limits for the evaluation figures.
curr_max_y = 0
x_lim = 0
max_y = -np.inf
min_y = np.inf
# hardcoded
min_y, max_y = 0.0, 1.0
grid = True  # arrange eval subplots in a near-square grid rather than a single row
draw_eval_legend = True
# Evaluation curves: one subplot per test environment, plus (optionally) an
# aggregated subplot averaging over all test environments.
if study_eval:
    print("Evaluation")
    # evaluation sets
    number_of_eval_envs = max(list([len(v.keys()) for v in model_eval_data.values()]))
    if plot_aggregated_test:
        # reserve one extra subplot for the aggregated curve
        number_of_eval_envs += 1
    if number_of_eval_envs == 0:
        print("No eval envs")
        exit()
    if plot_only_aggregated_test:
        f, ax = plt.subplots(1, 1, figsize=(9.0, 9.0))
    else:
        if grid:
            # grid: pick a subplot_y x subplot_x layout close to a square
            # that fits all eval envs
            subplot_y = math.ceil(math.sqrt(number_of_eval_envs))
            subplot_x = math.ceil(number_of_eval_envs / subplot_y)
            # from IPython import embed; embed()
            while subplot_x % 1 != 0:
                subplot_y -= 1
                subplot_x = number_of_eval_envs / subplot_y
            if subplot_x == 1:
                subplot_y = math.ceil(math.sqrt(number_of_eval_envs))
                subplot_x = math.floor(math.sqrt(number_of_eval_envs))
            subplot_y = int(subplot_y)
            subplot_x = int(subplot_x)
            assert subplot_y * subplot_x >= number_of_eval_envs
            f, ax_ = plt.subplots(subplot_y, subplot_x, figsize=(6.0, 6.0), sharey=False)  # , sharex=True, sharey=True)
            if subplot_y != 1:
                # flatten the 2D axes array into a flat list
                ax = list(chain.from_iterable(ax_))
            else:
                ax = ax_
        else:
            # flat
            # NOTE(review): in this branch subplot_x/subplot_y are never set,
            # but the per-env plotting below indexes them — confirm grid=True
            # is always used, or guard those uses.
            f, ax = plt.subplots(1, number_of_eval_envs, figsize=(15.0, 9.0))  # ), sharey=True, sharex=True)
            if number_of_eval_envs == 1:
                ax = [ax]
    default_colors = default_colors_.copy()
    # label -> {test_env -> final value}, and label -> color, for the summary
    test_summary_dict = defaultdict(dict)
    test_summary_dict_colors = defaultdict(dict)
    for model_i, m_id in enumerate(model_eval_data.keys()):
        # excluding some experiments
        if any([ex_pat in m_id for ex_pat in exclude_patterns]):
            continue
        if len(include_patterns) > 0:
            if not any([in_pat in m_id for in_pat in include_patterns]):
                continue
        # computes stats
        if sort_test:
            test_envs_sorted = enumerate(sorted(model_eval_data[m_id].items(), key=lambda kv: sort_test_set(kv[0])))
        else:
            test_envs_sorted = enumerate(model_eval_data[m_id].items())
        if plot_aggregated_test:
            agg_means = []
        for env_i, (test_env, env_data) in test_envs_sorted:
            ys_same_len = env_data["values"]
            # env steps averaged over seeds, in millions
            steps = env_data["steps"].mean(0) / steps_denom
            n_seeds = len(ys_same_len)
            if per_seed:
                # keep raw per-seed curves, no aggregation
                sems = np.array(ys_same_len)
                stds = np.array(ys_same_len)
                means = np.array(ys_same_len)
                color = default_colors[model_i]
            else:
                sems = np.std(ys_same_len, axis=0) / np.sqrt(len(ys_same_len))  # sem
                stds = np.std(ys_same_len, axis=0)  # std
                means = np.mean(ys_same_len, axis=0)
                color = default_colors[model_i]
            # per-metric adjusments
            if per_seed:
                # plot x y bounds
                curr_max_y = np.max(np.max(means))
                curr_min_y = np.min(np.min(means))
                curr_max_steps = np.max(np.max(steps))
            else:
                # plot x y bounds
                curr_max_y = np.max(means + stds)
                curr_min_y = np.min(means - stds)
                curr_max_steps = np.max(steps)
            if plot_aggregated_test:
                agg_means.append(means)
            if curr_max_y > max_y:
                max_y = curr_max_y
            if curr_min_y < min_y:
                min_y = curr_min_y
            x_lim = max(steps[-1], x_lim)
            x_lim = min(max_x_lim, x_lim)
            # human-readable metric / env names for titles and labels
            eval_metric_name = {
                "test_success_rates": "Success rate",
                'exploration_bonus_mean': "Exploration bonus",
            }.get(eval_metric, eval_metric)
            test_env_name = test_env.replace("Env", "").replace("Test", "")
            env_types = ["InformationSeeking", "Collaboration", "PerspectiveTaking"]
            for env_type in env_types:
                if env_type in test_env_name:
                    # move the env-type suffix onto its own line in parentheses
                    test_env_name = test_env_name.replace(env_type, "")
                    test_env_name += f"\n({env_type})"
            if grid:
                ylabel = eval_metric_name
                title = test_env_name
            else:
                # flat
                ylabel = test_env_name
                title = eval_metric_name
            leg_args = {
                'fontsize': legend_fontsize // 1
            }
            if per_seed:
                for s_i, seed_ys in enumerate(ys_same_len):
                    seed_c = default_colors[model_i + s_i]
                    # label = m_id#+"(s:{})".format(s_i)
                    label = str(s_i)
                    if not plot_only_aggregated_test:
                        seed_ys = smooth(seed_ys, eval_smooth_factor)
                        plot_with_shade_seed(0, ax[env_i], steps, seed_ys, None, seed_c, seed_c, label,
                                             legend=draw_eval_legend, xlim=[0, x_lim], ylim=[min_y, max_y],
                                             leg_size=leg_size, xlabel=f"Steps (1e6)", ylabel=ylabel, linewidth=linewidth, title=title)
                    # record the final value per seed for the textual summary
                    test_summary_dict[s_i][test_env] = seed_ys[-1]
                    test_summary_dict_colors[s_i] = seed_c
            else:
                label = label_parser(m_id, load_pattern, label_parser_dict=label_parser_dict)
                if not plot_only_aggregated_test:
                    if color_dict:
                        color = color_dict[label]
                    else:
                        color = default_colors[model_i]
                    label = label + "({})".format(n_seeds)
                    if smooth_factor:
                        means = smooth(means, eval_smooth_factor)
                        stds = smooth(stds, eval_smooth_factor)
                    plot_with_shade_grg(
                        0, ax[env_i], steps, means, stds, color, color, label,
                        legend=draw_eval_legend,
                        xlim=[0, x_lim+1],
                        ylim=[0, max_y],
                        xlabel=f"Env steps (1e6)" if env_i // (subplot_x) == subplot_y - 1 else None,  # only last line
                        ylabel=ylabel if env_i % subplot_x == 0 else None,  # only first row
                        title=title,
                        title_fontsize=title_fontsize,
                        labelsize=fontsize,
                        fontsize=fontsize,
                        linewidth=linewidth,
                        leg_linewidth=5,
                        leg_args=leg_args,
                        xnbins=xnbins,
                        ynbins=ynbins,
                    )
                # record the final mean value for the textual summary
                test_summary_dict[label][test_env] = means[-1]
                test_summary_dict_colors[label] = color
        if plot_aggregated_test:
            # aggregated curve: mean (and std) over this model's per-env mean curves
            if plot_only_aggregated_test:
                agg_env_i = 0
            else:
                agg_env_i = number_of_eval_envs - 1  # last one
            agg_means = np.array(agg_means)
            agg_mean = agg_means.mean(axis=0)
            agg_std = agg_means.std(axis=0)  # std
            if smooth_factor and not per_seed:
                agg_mean = smooth(agg_mean, eval_smooth_factor)
                agg_std = smooth(agg_std, eval_smooth_factor)
            if color_dict:
                # strip the "(n_seeds)" suffix appended above before the color lookup
                # NOTE(review): `label` is reused from the last env iteration — confirm.
                color = color_dict[re.sub("\([0-9]\)", '', label)]
            else:
                color = default_colors[model_i]
            if per_seed:
                print("Not smooth aggregated because of per seed")
                for s_i, (seed_ys, seed_st) in enumerate(zip(agg_mean, agg_std)):
                    seed_c = default_colors[model_i + s_i]
                    # label = m_id#+"(s:{})".format(s_i)
                    label = str(s_i)
                    # seed_ys = smooth(seed_ys, eval_smooth_factor)
                    plot_with_shade_seed(0,
                                         ax if plot_only_aggregated_test else ax[agg_env_i],
                                         steps, seed_ys, seed_st, seed_c, seed_c, label,
                                         legend=draw_eval_legend, xlim=[0, x_lim], ylim=[min_y, max_y],
                                         labelsize=fontsize,
                                         filename=eval_filename,
                                         leg_size=leg_size, xlabel=f"Steps (1e6)", ylabel=ylabel, linewidth=1, title=agg_title)
            else:
                # just used for creating a dummy Imitation test figure -> delete
                # agg_mean = agg_mean * 0.1
                # agg_std = agg_std * 0.1
                # max_y = 1
                plot_with_shade_grg(
                    0,
                    ax if plot_only_aggregated_test else ax[agg_env_i],
                    steps, agg_mean, agg_std, color, color, label,
                    legend=draw_eval_legend,
                    xlim=[0, x_lim + 1],
                    ylim=[0, max_y],
                    xlabel=f"Steps (1e6)" if plot_only_aggregated_test or (agg_env_i // (subplot_x) == subplot_y - 1) else None,  # only last line
                    ylabel=ylabel if plot_only_aggregated_test or (agg_env_i % subplot_x == 0) else None,  # only first row
                    title_fontsize=title_fontsize,
                    title=agg_title,
                    labelsize=fontsize,
                    fontsize=fontsize,
                    linewidth=linewidth,
                    leg_linewidth=5,
                    leg_args=leg_args,
                    xnbins=xnbins,
                    ynbins=ynbins,
                    filename=eval_filename,
                )
    # print summary
    # Per-label mean of final values across all test envs.
    means_dict = {
        lab: np.array(list(lab_sd.values())).mean() for lab, lab_sd in test_summary_dict.items()
    }
    best = max(means_dict.values())
    pc = 0.3
    n = int(len(means_dict) * pc)
    print("top n: ", n)
    top_pc = sorted(means_dict.values())[-n:]
    bottom_pc = sorted(means_dict.values())[:n]
    print("Legend:")
    cprint("\tbest", "green")
    cprint("\ttop {} %".format(pc), "blue")
    cprint("\tbottom {} %".format(pc), "red")
    print("\tothers")
    print()
    for l, l_mean in sorted(means_dict.items(), key=lambda kv: kv[1]):
        l_summary_dict = test_summary_dict[l]
        c = test_summary_dict_colors[l]
        print("label: {} ({})".format(l, c))
        # print("\t{}({}) - Mean".format(l_mean, metric))
        if l_mean == best:
            cprint("\t{}({}) - Mean".format(l_mean, eval_metric), "green")
        elif l_mean in top_pc:
            cprint("\t{}({}) - Mean".format(l_mean, eval_metric), "blue")
        elif l_mean in bottom_pc:
            cprint("\t{}({}) - Mean".format(l_mean, eval_metric), "red")
        else:
            print("\t{}({})".format(l_mean, eval_metric))
        n_over_50 = 0
        if sort_test:
            sorted_envs = sorted(l_summary_dict.items(), key=lambda kv: sort_test_set(env_name=kv[0]))
        else:
            sorted_envs = l_summary_dict.items()
        # a '*' marks test envs whose final value is >= 0.5
        for tenv, p in sorted_envs:
            if p < 0.5:
                print("\t{:4f}({}) - \t{}".format(p, eval_metric, tenv))
            else:
                print("\t{:4f}({}) -*\t{}".format(p, eval_metric, tenv))
                n_over_50 += 1
        print("\tenv over 50 - {}/{}".format(n_over_50, len(l_summary_dict)))
# Show the evaluation figure (blocking) if requested.
if plot_test:
    plt.tight_layout()
    # plt.subplots_adjust(hspace=0.8, wspace=0.15, left=0.035, right=0.99, bottom=0.065, top=0.93)
    plt.show()
# Interactively offer to save the current figure to `eval_filename`.
if eval_filename is not None:
    plt.subplots_adjust(hspace=0.8, wspace=0.15, left=0.15, right=0.99, bottom=0.15, top=0.93)
    res = input(f"Save to {eval_filename} (y/n)?")
    if res == "y":
        f.savefig(eval_filename)
        print(f'saved to {eval_filename}')
    else:
        print('not saved')