"""plot_prediction_planning_evaluation.py --load_from --seed --scene_type --risk_level --num_samples This script plots statistics of evaluation results generated by evaluate_prediction_planning_stack.py or evaluate_prediction_planning_stack_with_replanning.py. Add --with_replanning flag to plot results with re-planning, otherwise open-loop evaluations are used. """ import argparse import os import pickle from typing import List import matplotlib.pyplot as plt import numpy as np import scipy.stats as st def plot_main( stats_dir: str, scene_type: str, risk_level_list: List[float], num_prediction_samples_list: List[int], ) -> None: if not "with_replanning" in stats_dir: if 0.0 in risk_level_list: plot_computation_time( stats_dir, scene_type, num_prediction_samples_list=num_prediction_samples_list, ) plot_varying_risk( stats_dir, scene_type, num_prediction_samples_list=[num_prediction_samples_list[-1]], risk_level_list=risk_level_list, risk_in_planner=True, ) plot_varying_risk( stats_dir, scene_type, num_prediction_samples_list=[num_prediction_samples_list[-1]], risk_level_list=risk_level_list, risk_in_planner=False, ) plot_policy_comparison( stats_dir, scene_type, num_prediction_samples_list=num_prediction_samples_list, risk_level_list=list(filter(lambda r: r != 0.0, risk_level_list)), ) # How does computation time scale as we increase the number of samples? def plot_computation_time( stats_dir: str, scene_type: str, num_prediction_samples_list: List[int], alpha_for_confint: float = 0.95, ) -> None: risk_level = 0.0 stats_dict_zero_risk = dict() computation_time_mean_list, computation_time_sem_list = [], [] for num_samples in num_prediction_samples_list: file_path = os.path.join( stats_dir, f"{scene_type}_{num_samples}_samples_risk_level_{risk_level}.pkl", ) assert os.path.exists( file_path ), f"missing experiment with num_samples == {num_samples} and risk_level == {risk_level}" with open(file_path, "rb") as f: stats_dict_zero_risk[num_samples] = pickle.load(f) num_episodes = _get_num_episodes(stats_dict_zero_risk[num_samples]) computation_time_list = [ stats_dict_zero_risk[num_samples][idx]["computation_time_ms"] for idx in range(num_episodes) ] computation_time_mean_list.append(np.mean(computation_time_list)) computation_time_sem_list.append(st.sem(computation_time_list)) # ref: https://www.statology.org/confidence-intervals-python/ confint_lower, confint_upper = st.norm.interval( alpha=alpha_for_confint, loc=computation_time_mean_list, scale=computation_time_sem_list, ) _, ax = plt.subplots(1, figsize=(6, 6)) ax.plot( num_prediction_samples_list, computation_time_mean_list, color="skyblue", linewidth=2.0, ) ax.fill_between( num_prediction_samples_list, confint_upper, confint_lower, facecolor="skyblue", alpha=0.3, ) ax.set_xlabel("Number of Prediction Samples") ax.set_ylabel("Computation Time for Prediction and Planning (ms)") plt.show() # How do varying risk-levels affect the safety/efficiency of the policy? 
def plot_varying_risk(
    stats_dir: str,
    scene_type: str,
    num_prediction_samples_list: List[int],
    risk_level_list: List[float],
    risk_in_planner: bool = False,
    alpha_for_confint: float = 0.95,
) -> None:
    _, ax = plt.subplots(
        1,
        len(num_prediction_samples_list),
        figsize=(6 * len(num_prediction_samples_list), 6),
    )
    if not isinstance(ax, np.ndarray):
        # A single subplot is returned as a bare Axes; wrap it so that
        # ax[plot_idx] indexing works below.
        ax = [ax]
    stats_dict = dict()
    suptitle = "Safety-Efficiency Tradeoff of Optimized Policy"
    if "with_replanning" in stats_dir:
        suptitle += " with Replanning"
    if risk_in_planner:
        suptitle += " (Risk in Planner)"
    else:
        suptitle += " (Risk in Predictor)"
    plt.suptitle(suptitle)
    for plot_idx, num_samples in enumerate(num_prediction_samples_list):
        stats_dict[num_samples] = dict()
        interaction_cost_mean_list, interaction_cost_sem_list = [], []
        tracking_cost_mean_list, tracking_cost_sem_list = [], []
        for risk_level in risk_level_list:
            if risk_level == 0.0:
                file_path = os.path.join(
                    stats_dir,
                    f"{scene_type}_{num_samples}_samples_risk_level_{risk_level}.pkl",
                )
            elif risk_in_planner:
                file_path = os.path.join(
                    stats_dir,
                    f"{scene_type}_{num_samples}_samples_risk_level_{risk_level}_in_planner.pkl",
                )
            else:
                file_path = os.path.join(
                    stats_dir,
                    f"{scene_type}_{num_samples}_samples_risk_level_{risk_level}_in_predictor.pkl",
                )
            assert os.path.exists(
                file_path
            ), f"missing experiment with num_samples == {num_samples} and risk_level == {risk_level}"
            with open(file_path, "rb") as f:
                stats_dict[num_samples][risk_level] = pickle.load(f)
            num_episodes = _get_num_episodes(stats_dict[num_samples][risk_level])
            interaction_cost_list = [
                stats_dict[num_samples][risk_level][idx][
                    "interaction_cost_ground_truth"
                ]
                for idx in range(num_episodes)
            ]
            interaction_cost_mean_list.append(np.mean(interaction_cost_list))
            interaction_cost_sem_list.append(st.sem(interaction_cost_list))
            tracking_cost_list = [
                stats_dict[num_samples][risk_level][idx]["tracking_cost"]
                for idx in range(num_episodes)
            ]
            tracking_cost_mean_list.append(np.mean(tracking_cost_list))
            tracking_cost_sem_list.append(st.sem(tracking_cost_list))
        (
            interaction_cost_confint_lower,
            interaction_cost_confint_upper,
        ) = st.norm.interval(
            alpha_for_confint,
            loc=interaction_cost_mean_list,
            scale=interaction_cost_sem_list,
        )
        (
            tracking_cost_confint_lower,
            tracking_cost_confint_upper,
        ) = st.norm.interval(
            alpha_for_confint,
            loc=tracking_cost_mean_list,
            scale=tracking_cost_sem_list,
        )
        ax[plot_idx].plot(
            risk_level_list,
            interaction_cost_mean_list,
            color="orange",
            linewidth=2.0,
            label="ground-truth collision cost",
        )
        ax[plot_idx].fill_between(
            risk_level_list,
            interaction_cost_confint_upper,
            interaction_cost_confint_lower,
            color="orange",
            alpha=0.3,
        )
        ax[plot_idx].plot(
            risk_level_list,
            tracking_cost_mean_list,
            color="lightgreen",
            linewidth=2.0,
            label="trajectory tracking cost",
        )
        ax[plot_idx].fill_between(
            risk_level_list,
            tracking_cost_confint_upper,
            tracking_cost_confint_lower,
            color="lightgreen",
            alpha=0.3,
        )
        if risk_in_planner:
            ax[plot_idx].set_xlabel("Risk-Sensitivity Level (in Planner)")
        else:
            ax[plot_idx].set_xlabel("Risk-Sensitivity Level (in Predictor)")
        ax[plot_idx].set_ylabel("Cost")
        ax[plot_idx].set_title(f"Number of Prediction Samples: {num_samples}")
        ax[plot_idx].legend(loc="upper right")
    plt.show()


# How does (risk-biased predictor + risk-neutral planner) compare with
# (risk-neutral predictor + risk-sensitive planner) in terms of
# characteristics of the optimized policy?
def plot_policy_comparison(
    stats_dir: str,
    scene_type: str,
    num_prediction_samples_list: List[int],
    risk_level_list: List[float],
    alpha_for_confint: float = 0.95,
) -> None:
    assert 0.0 not in risk_level_list
    num_rows = 2 if "with_replanning" in stats_dir else 4
    _, ax = plt.subplots(
        num_rows,
        len(risk_level_list),
        figsize=(6 * len(risk_level_list), 6 * num_rows),
    )
    if len(risk_level_list) == 1:
        # With a single column, plt.subplots returns a 1-D array of Axes;
        # reshape to 2-D so that ax[row_idx][plot_idx] indexing works below.
        ax = ax.reshape(num_rows, 1)
    suptitle = "Characteristics of Optimized Policy"
    if "with_replanning" in stats_dir:
        suptitle += " with Replanning"
    plt.suptitle(suptitle)
    predictor_stats_dict, planner_stats_dict = dict(), dict()
    for plot_idx, risk_level in enumerate(risk_level_list):
        predictor_stats_dict[risk_level], planner_stats_dict[risk_level] = (
            dict(),
            dict(),
        )
        predictor_interaction_cost_mean_list, planner_interaction_cost_mean_list = (
            [],
            [],
        )
        predictor_interaction_cost_sem_list, planner_interaction_cost_sem_list = [], []
        predictor_tracking_cost_mean_list, planner_tracking_cost_mean_list = [], []
        predictor_tracking_cost_sem_list, planner_tracking_cost_sem_list = [], []
        if "with_replanning" not in stats_dir:
            predictor_interaction_risk_mean_list, planner_interaction_risk_mean_list = (
                [],
                [],
            )
            predictor_interaction_risk_sem_list, planner_interaction_risk_sem_list = (
                [],
                [],
            )
            predictor_total_objective_mean_list, planner_total_objective_mean_list = (
                [],
                [],
            )
            predictor_total_objective_sem_list, planner_total_objective_sem_list = (
                [],
                [],
            )
        for num_samples in num_prediction_samples_list:
            file_path = os.path.join(
                stats_dir,
                f"{scene_type}_{num_samples}_samples_risk_level_{risk_level}_in_predictor.pkl",
            )
            assert os.path.exists(
                file_path
            ), f"missing experiment with num_samples == {num_samples} and risk_level == {risk_level}"
            with open(file_path, "rb") as f:
                predictor_stats_dict[risk_level][num_samples] = pickle.load(f)
            predictor_num_episodes = _get_num_episodes(
                predictor_stats_dict[risk_level][num_samples]
            )
            predictor_interaction_cost_list = [
                predictor_stats_dict[risk_level][num_samples][idx][
                    "interaction_cost_ground_truth"
                ]
                for idx in range(predictor_num_episodes)
            ]
            predictor_interaction_cost_mean_list.append(
                np.mean(predictor_interaction_cost_list)
            )
            predictor_interaction_cost_sem_list.append(
                st.sem(predictor_interaction_cost_list)
            )
            predictor_tracking_cost_list = [
                predictor_stats_dict[risk_level][num_samples][idx]["tracking_cost"]
                for idx in range(predictor_num_episodes)
            ]
            predictor_tracking_cost_mean_list.append(
                np.mean(predictor_tracking_cost_list)
            )
            predictor_tracking_cost_sem_list.append(
                st.sem(predictor_tracking_cost_list)
            )
            if "with_replanning" not in stats_dir:
                predictor_interaction_risk_list = [
                    predictor_stats_dict[risk_level][num_samples][idx][
                        "interaction_risk"
                    ]
                    for idx in range(predictor_num_episodes)
                ]
                predictor_interaction_risk_mean_list.append(
                    np.mean(predictor_interaction_risk_list)
                )
                predictor_interaction_risk_sem_list.append(
                    st.sem(predictor_interaction_risk_list)
                )
                predictor_total_objective_list = [
                    interaction_risk + tracking_cost
                    for (interaction_risk, tracking_cost) in zip(
                        predictor_interaction_risk_list, predictor_tracking_cost_list
                    )
                ]
                predictor_total_objective_mean_list.append(
                    np.mean(predictor_total_objective_list)
                )
                predictor_total_objective_sem_list.append(
                    st.sem(predictor_total_objective_list)
                )
            file_path = os.path.join(
                stats_dir,
                f"{scene_type}_{num_samples}_samples_risk_level_{risk_level}_in_planner.pkl",
            )
            assert os.path.exists(
                file_path
            ), f"missing experiment with num_samples == {num_samples} and risk_level == {risk_level}"
            with open(file_path, "rb") as f:
                planner_stats_dict[risk_level][num_samples] = pickle.load(f)
            planner_num_episodes = _get_num_episodes(
                planner_stats_dict[risk_level][num_samples]
            )
            planner_interaction_cost_list = [
                planner_stats_dict[risk_level][num_samples][idx][
                    "interaction_cost_ground_truth"
                ]
                for idx in range(planner_num_episodes)
            ]
            planner_interaction_cost_mean_list.append(
                np.mean(planner_interaction_cost_list)
            )
            planner_interaction_cost_sem_list.append(
                st.sem(planner_interaction_cost_list)
            )
            planner_tracking_cost_list = [
                planner_stats_dict[risk_level][num_samples][idx]["tracking_cost"]
                for idx in range(planner_num_episodes)
            ]
            planner_tracking_cost_mean_list.append(np.mean(planner_tracking_cost_list))
            planner_tracking_cost_sem_list.append(st.sem(planner_tracking_cost_list))
            if "with_replanning" not in stats_dir:
                planner_interaction_risk_list = [
                    planner_stats_dict[risk_level][num_samples][idx]["interaction_risk"]
                    for idx in range(planner_num_episodes)
                ]
                planner_interaction_risk_mean_list.append(
                    np.mean(planner_interaction_risk_list)
                )
                planner_interaction_risk_sem_list.append(
                    st.sem(planner_interaction_risk_list)
                )
                planner_total_objective_list = [
                    interaction_risk + tracking_cost
                    for (interaction_risk, tracking_cost) in zip(
                        planner_interaction_risk_list, planner_tracking_cost_list
                    )
                ]
                planner_total_objective_mean_list.append(
                    np.mean(planner_total_objective_list)
                )
                planner_total_objective_sem_list.append(
                    st.sem(planner_total_objective_list)
                )
        (
            predictor_interaction_cost_confint_lower,
            predictor_interaction_cost_confint_upper,
        ) = st.norm.interval(
            alpha_for_confint,
            loc=predictor_interaction_cost_mean_list,
            scale=predictor_interaction_cost_sem_list,
        )
        (
            predictor_tracking_cost_confint_lower,
            predictor_tracking_cost_confint_upper,
        ) = st.norm.interval(
            alpha_for_confint,
            loc=predictor_tracking_cost_mean_list,
            scale=predictor_tracking_cost_sem_list,
        )
        if "with_replanning" not in stats_dir:
            (
                predictor_interaction_risk_confint_lower,
                predictor_interaction_risk_confint_upper,
            ) = st.norm.interval(
                alpha_for_confint,
                loc=predictor_interaction_risk_mean_list,
                scale=predictor_interaction_risk_sem_list,
            )
            (
                predictor_total_objective_confint_lower,
                predictor_total_objective_confint_upper,
            ) = st.norm.interval(
                alpha_for_confint,
                loc=predictor_total_objective_mean_list,
                scale=predictor_total_objective_sem_list,
            )
        (
            planner_interaction_cost_confint_lower,
            planner_interaction_cost_confint_upper,
        ) = st.norm.interval(
            alpha_for_confint,
            loc=planner_interaction_cost_mean_list,
            scale=planner_interaction_cost_sem_list,
        )
        (
            planner_tracking_cost_confint_lower,
            planner_tracking_cost_confint_upper,
        ) = st.norm.interval(
            alpha_for_confint,
            loc=planner_tracking_cost_mean_list,
            scale=planner_tracking_cost_sem_list,
        )
        if "with_replanning" not in stats_dir:
            (
                planner_interaction_risk_confint_lower,
                planner_interaction_risk_confint_upper,
            ) = st.norm.interval(
                alpha_for_confint,
                loc=planner_interaction_risk_mean_list,
                scale=planner_interaction_risk_sem_list,
            )
            (
                planner_total_objective_confint_lower,
                planner_total_objective_confint_upper,
            ) = st.norm.interval(
                alpha_for_confint,
                loc=planner_total_objective_mean_list,
                scale=planner_total_objective_sem_list,
            )
        ax[0][plot_idx].plot(
            num_prediction_samples_list,
            planner_interaction_cost_mean_list,
            color="skyblue",
            linewidth=2.0,
            label="risk in planner",
        )
        ax[0][plot_idx].fill_between(
            num_prediction_samples_list,
            planner_interaction_cost_confint_upper,
            planner_interaction_cost_confint_lower,
            color="skyblue",
            alpha=0.3,
        )
        ax[0][plot_idx].plot(
            num_prediction_samples_list,
            predictor_interaction_cost_mean_list,
            color="orange",
            linewidth=2.0,
            label="risk in predictor",
        )
        ax[0][plot_idx].fill_between(
            num_prediction_samples_list,
            predictor_interaction_cost_confint_upper,
            predictor_interaction_cost_confint_lower,
            color="orange",
            alpha=0.3,
        )
        ax[0][plot_idx].set_xlabel("Number of Prediction Samples")
        ax[0][plot_idx].set_ylabel("Ground-Truth Collision Cost")
        ax[0][plot_idx].set_title(f"Risk-Sensitivity Level: {risk_level}")
        ax[0][plot_idx].legend(loc="upper right")
        ax[0][plot_idx].set_xscale("log")
        ax[1][plot_idx].plot(
            num_prediction_samples_list,
            planner_tracking_cost_mean_list,
            color="skyblue",
            linewidth=2.0,
            label="risk in planner",
        )
        ax[1][plot_idx].fill_between(
            num_prediction_samples_list,
            planner_tracking_cost_confint_upper,
            planner_tracking_cost_confint_lower,
            color="skyblue",
            alpha=0.3,
        )
        ax[1][plot_idx].plot(
            num_prediction_samples_list,
            predictor_tracking_cost_mean_list,
            color="orange",
            linewidth=2.0,
            label="risk in predictor",
        )
        ax[1][plot_idx].fill_between(
            num_prediction_samples_list,
            predictor_tracking_cost_confint_upper,
            predictor_tracking_cost_confint_lower,
            color="orange",
            alpha=0.3,
        )
        ax[1][plot_idx].set_xlabel("Number of Prediction Samples")
        ax[1][plot_idx].set_ylabel("Trajectory Tracking Cost")
        # ax[1][plot_idx].set_title(f"Risk-Sensitivity Level: {risk_level}")
        ax[1][plot_idx].legend(loc="lower right")
        ax[1][plot_idx].set_xscale("log")
        if "with_replanning" not in stats_dir:
            ax[2][plot_idx].plot(
                num_prediction_samples_list,
                planner_interaction_risk_mean_list,
                color="skyblue",
                linewidth=2.0,
                label="risk in planner",
            )
            ax[2][plot_idx].fill_between(
                num_prediction_samples_list,
                planner_interaction_risk_confint_upper,
                planner_interaction_risk_confint_lower,
                color="skyblue",
                alpha=0.3,
            )
            ax[2][plot_idx].plot(
                num_prediction_samples_list,
                predictor_interaction_risk_mean_list,
                color="orange",
                linewidth=2.0,
                label="risk in predictor",
            )
            ax[2][plot_idx].fill_between(
                num_prediction_samples_list,
                predictor_interaction_risk_confint_upper,
                predictor_interaction_risk_confint_lower,
                color="orange",
                alpha=0.3,
            )
            ax[2][plot_idx].set_xlabel("Number of Prediction Samples")
            ax[2][plot_idx].set_ylabel("Collision Risk")
            # ax[2][plot_idx].set_title(f"Risk-Sensitivity Level: {risk_level}")
            ax[2][plot_idx].legend(loc="upper right")
            ax[2][plot_idx].set_xscale("log")
            ax[3][plot_idx].plot(
                num_prediction_samples_list,
                planner_total_objective_mean_list,
                color="skyblue",
                linewidth=2.0,
                label="risk in planner",
            )
            ax[3][plot_idx].fill_between(
                num_prediction_samples_list,
                planner_total_objective_confint_upper,
                planner_total_objective_confint_lower,
                color="skyblue",
                alpha=0.3,
            )
            ax[3][plot_idx].plot(
                num_prediction_samples_list,
                predictor_total_objective_mean_list,
                color="orange",
                linewidth=2.0,
                label="risk in predictor",
            )
            ax[3][plot_idx].fill_between(
                num_prediction_samples_list,
                predictor_total_objective_confint_upper,
                predictor_total_objective_confint_lower,
                color="orange",
                alpha=0.3,
            )
            ax[3][plot_idx].set_xlabel("Number of Prediction Samples")
            ax[3][plot_idx].set_ylabel("Planner's Total Objective")
            # ax[3][plot_idx].set_title(f"Risk-Sensitivity Level: {risk_level}")
            ax[3][plot_idx].legend(loc="upper right")
            ax[3][plot_idx].set_xscale("log")
    plt.show()


def _get_num_episodes(stats_dict: dict) -> int:
    # Episode indices are the integer keys of the stats dict; any non-integer
    # keys are ignored. The episode count is the largest index plus one.
    return max(filter(lambda key: isinstance(key, int), stats_dict)) + 1


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="visualize evaluation result of evaluate_prediction_planning_stack.py"
    )
    parser.add_argument(
        "--load_from",
        type=str,
        required=True,
        help="WandB ID for specification of trained predictor",
    )
    parser.add_argument(
        "--seed",
        type=int,
        required=False,
        default=0,
    )
    parser.add_argument(
        "--scene_type",
        type=str,
        choices=["safer_fast", "safer_slow"],
        required=True,
    )
    parser.add_argument(
        "--with_replanning",
        action="store_true",
    )
    parser.add_argument(
        "--risk_level",
        type=float,
        nargs="+",
        help="Risk-sensitivity level(s) to test",
        default=[0.95, 1.0],
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        nargs="+",
        help="Number(s) of prediction samples to test",
        default=[1, 4, 16, 64, 256, 1024],
    )
    parser.add_argument(
        "--force_config",
        action="store_true",
        help="""Use this flag to force the use of the local config file when loading
        a model from a checkpoint. Otherwise the checkpoint config file is used. In
        any case the parameters can be overwritten with an argparse argument.""",
    )
    args = parser.parse_args()
    dir_name = (
        "planner_eval_with_replanning" if args.with_replanning else "planner_eval"
    )
    stats_dir = os.path.join(
        os.path.dirname(os.path.realpath(__file__)),
        "logs",
        dir_name,
        f"run-{args.load_from}_{args.seed}",
    )
    postfix_string = "_with_replanning" if args.with_replanning else ""
    assert os.path.exists(
        stats_dir
    ), f"{stats_dir} does not exist. Did you run 'evaluate_prediction_planning_stack{postfix_string}.py --load_from {args.load_from} --seed {args.seed}' ?"
    plot_main(stats_dir, args.scene_type, args.risk_level, args.num_samples)
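

# Example invocations (illustrative only: "abcd1234" is a placeholder for a
# real WandB run ID, and the corresponding evaluation script must have been
# run first so that the expected .pkl files exist under logs/):
#
#   python plot_prediction_planning_evaluation.py --load_from abcd1234 --seed 0 \
#       --scene_type safer_fast --risk_level 0.0 0.95 1.0 --num_samples 1 4 16 64
#
#   python plot_prediction_planning_evaluation.py --load_from abcd1234 --seed 0 \
#       --scene_type safer_slow --with_replanning
#
# Including 0.0 in --risk_level additionally triggers the computation-time
# plot in the open-loop case; with --with_replanning, risk level 0.0 must be
# omitted because plot_policy_comparison asserts it is absent.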