# risk_biased_prediction/scripts/eval_scripts/plot_prediction_planning_evaluation.py
"""plot_prediction_planning_evaluation.py --load_from <wandb ID> --seed <seed>
--scene_type <safer_fast or safer_slow> --risk_level <a list of risk-levels>
--num_samples <a list of numbers of prediction samples>
This script plots statistics of evaluation results generated by
evaluate_prediction_planning_stack.py or evaluate_prediction_planning_stack_with_replanning.py.
Add --with_replanning flag to plot results with re-planning, otherwise open-loop evaluations are
used.
"""
import argparse
import os
import pickle
from typing import List

import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as st


def plot_main(
stats_dir: str,
scene_type: str,
risk_level_list: List[float],
num_prediction_samples_list: List[int],
) -> None:
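    """Dispatch the plotting routines below for the given scene type,
    risk levels, and prediction-sample counts."""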
if not "with_replanning" in stats_dir:
if 0.0 in risk_level_list:
plot_computation_time(
stats_dir,
scene_type,
num_prediction_samples_list=num_prediction_samples_list,
)
plot_varying_risk(
stats_dir,
scene_type,
num_prediction_samples_list=[num_prediction_samples_list[-1]],
risk_level_list=risk_level_list,
risk_in_planner=True,
)
plot_varying_risk(
stats_dir,
scene_type,
num_prediction_samples_list=[num_prediction_samples_list[-1]],
risk_level_list=risk_level_list,
risk_in_planner=False,
)
plot_policy_comparison(
stats_dir,
scene_type,
num_prediction_samples_list=num_prediction_samples_list,
risk_level_list=list(filter(lambda r: r != 0.0, risk_level_list)),
    )


# How does computation time scale as we increase the number of samples?
def plot_computation_time(
stats_dir: str,
scene_type: str,
num_prediction_samples_list: List[int],
alpha_for_confint: float = 0.95,
) -> None:
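    # Timing is only measured for the risk-neutral setting (risk level 0.0),
    # whose result files carry no "_in_planner"/"_in_predictor" suffix.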
risk_level = 0.0
stats_dict_zero_risk = dict()
computation_time_mean_list, computation_time_sem_list = [], []
for num_samples in num_prediction_samples_list:
file_path = os.path.join(
stats_dir,
f"{scene_type}_{num_samples}_samples_risk_level_{risk_level}.pkl",
)
assert os.path.exists(
file_path
), f"missing experiment with num_samples == {num_samples} and risk_level == {risk_level}"
with open(file_path, "rb") as f:
stats_dict_zero_risk[num_samples] = pickle.load(f)
num_episodes = _get_num_episodes(stats_dict_zero_risk[num_samples])
computation_time_list = [
stats_dict_zero_risk[num_samples][idx]["computation_time_ms"]
for idx in range(num_episodes)
]
computation_time_mean_list.append(np.mean(computation_time_list))
computation_time_sem_list.append(st.sem(computation_time_list))
# ref: https://www.statology.org/confidence-intervals-python/
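    # Normal-approximation interval: mean -/+ z * SEM (z ~ 1.96 at 95%).
    # Note: SciPy >= 1.9 renames this `alpha` keyword to `confidence`, so an
    # older SciPy is assumed here.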
confint_lower, confint_upper = st.norm.interval(
alpha=alpha_for_confint,
loc=computation_time_mean_list,
scale=computation_time_sem_list,
)
_, ax = plt.subplots(1, figsize=(6, 6))
ax.plot(
num_prediction_samples_list,
computation_time_mean_list,
color="skyblue",
linewidth=2.0,
)
ax.fill_between(
num_prediction_samples_list,
confint_upper,
confint_lower,
facecolor="skyblue",
alpha=0.3,
)
ax.set_xlabel("Number of Prediction Samples")
ax.set_ylabel("Computation Time for Prediction and Planning (ms)")
    plt.show()


# How do varying risk levels affect the safety/efficiency of the policy?
def plot_varying_risk(
stats_dir: str,
scene_type: str,
num_prediction_samples_list: List[int],
risk_level_list: List[float],
risk_in_planner: bool = False,
alpha_for_confint: float = 0.95,
) -> None:
_, ax = plt.subplots(
1,
len(num_prediction_samples_list),
figsize=(6 * len(num_prediction_samples_list), 6),
)
    if not isinstance(ax, np.ndarray):
ax = [ax]
stats_dict = dict()
suptitle = "Safety-Efficiency Tradeoff of Optimized Policy"
if "with_replanning" in stats_dir:
suptitle += " with Replanning"
if risk_in_planner:
suptitle += " (Risk in Planner)"
else:
suptitle += " (Risk in Predictor)"
plt.suptitle(suptitle)
    for plot_idx, num_samples in enumerate(num_prediction_samples_list):
stats_dict[num_samples] = dict()
interaction_cost_mean_list, interaction_cost_sem_list = [], []
tracking_cost_mean_list, tracking_cost_sem_list = [], []
for risk_level in risk_level_list:
if risk_level == 0.0:
file_path = os.path.join(
stats_dir,
f"{scene_type}_{num_samples}_samples_risk_level_{risk_level}.pkl",
)
elif risk_in_planner:
file_path = os.path.join(
stats_dir,
f"{scene_type}_{num_samples}_samples_risk_level_{risk_level}_in_planner.pkl",
)
else:
file_path = os.path.join(
stats_dir,
f"{scene_type}_{num_samples}_samples_risk_level_{risk_level}_in_predictor.pkl",
)
assert os.path.exists(
file_path
), f"missing experiment with num_samples == {num_samples} and risk_level == {risk_level}"
with open(file_path, "rb") as f:
stats_dict[num_samples][risk_level] = pickle.load(f)
num_episodes = _get_num_episodes(stats_dict[num_samples][risk_level])
interaction_cost_list = [
stats_dict[num_samples][risk_level][idx][
"interaction_cost_ground_truth"
]
for idx in range(num_episodes)
]
interaction_cost_mean_list.append(np.mean(interaction_cost_list))
interaction_cost_sem_list.append(st.sem(interaction_cost_list))
tracking_cost_list = [
stats_dict[num_samples][risk_level][idx]["tracking_cost"]
for idx in range(num_episodes)
]
tracking_cost_mean_list.append(np.mean(tracking_cost_list))
tracking_cost_sem_list.append(st.sem(tracking_cost_list))
(
interaction_cost_confint_lower,
interaction_cost_confint_upper,
) = st.norm.interval(
alpha=alpha_for_confint,
loc=interaction_cost_mean_list,
scale=interaction_cost_sem_list,
)
(tracking_cost_confint_lower, tracking_cost_confint_upper,) = st.norm.interval(
alpha=alpha_for_confint,
loc=tracking_cost_mean_list,
scale=tracking_cost_sem_list,
)
ax[plot_idx].plot(
risk_level_list,
interaction_cost_mean_list,
color="orange",
linewidth=2.0,
label="ground-truth collision cost",
)
ax[plot_idx].fill_between(
risk_level_list,
interaction_cost_confint_upper,
interaction_cost_confint_lower,
color="orange",
alpha=0.3,
)
ax[plot_idx].plot(
risk_level_list,
tracking_cost_mean_list,
color="lightgreen",
linewidth=2.0,
label="trajectory tracking cost",
)
ax[plot_idx].fill_between(
risk_level_list,
tracking_cost_confint_upper,
tracking_cost_confint_lower,
color="lightgreen",
alpha=0.3,
)
if risk_in_planner:
ax[plot_idx].set_xlabel("Risk-Sensitivity Level (in Planner)")
else:
ax[plot_idx].set_xlabel("Risk-Sensitivity Level (in Predictor)")
ax[plot_idx].set_ylabel("Cost")
ax[plot_idx].set_title(f"Number of Prediction Samples: {num_samples}")
ax[plot_idx].legend(loc="upper right")
    plt.show()


# How does (risk-biased predictor + risk-neutral planner) compare with
# (risk-neutral predictor + risk-sensitive planner) in terms of the
# characteristics of the optimized policy?
def plot_policy_comparison(
stats_dir: str,
scene_type: str,
num_prediction_samples_list: List[int],
risk_level_list: List[float],
alpha_for_confint: float = 0.95,
) -> None:
    assert (
        0.0 not in risk_level_list
    ), "risk level 0.0 is risk-neutral and has no in-planner/in-predictor variant"
num_rows = 2 if "with_replanning" in stats_dir else 4
_, ax = plt.subplots(
num_rows, len(risk_level_list), figsize=(6 * len(risk_level_list), 6 * num_rows)
)
    if len(risk_level_list) == 1:
        # plt.subplots squeezes a single-column layout into a 1-D array;
        # restore two dimensions so that ax[row][col] indexing works below.
        ax = ax.reshape(num_rows, 1)
suptitle = "Characteristics of Optimized Policy"
if "with_replanning" in stats_dir:
suptitle += " with Replanning"
plt.suptitle(suptitle)
predictor_stats_dict, planner_stats_dict = dict(), dict()
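    # "predictor" stats come from runs with risk applied in the predictor,
    # "planner" stats from runs with risk applied in the planner (see the
    # "_in_predictor"/"_in_planner" file-name suffixes below).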
    for plot_idx, risk_level in enumerate(risk_level_list):
predictor_stats_dict[risk_level], planner_stats_dict[risk_level] = (
dict(),
dict(),
)
predictor_interaction_cost_mean_list, planner_interaction_cost_mean_list = (
[],
[],
)
predictor_interaction_cost_sem_list, planner_interaction_cost_sem_list = [], []
predictor_tracking_cost_mean_list, planner_tracking_cost_mean_list = [], []
predictor_tracking_cost_sem_list, planner_tracking_cost_sem_list = [], []
if not "with_replanning" in stats_dir:
predictor_interaction_risk_mean_list, planner_interaction_risk_mean_list = (
[],
[],
)
predictor_interaction_risk_sem_list, planner_interaction_risk_sem_list = (
[],
[],
)
predictor_total_objective_mean_list, planner_total_objective_mean_list = (
[],
[],
)
predictor_total_objective_sem_list, planner_total_objective_sem_list = (
[],
[],
)
for num_samples in num_prediction_samples_list:
file_path = os.path.join(
stats_dir,
f"{scene_type}_{num_samples}_samples_risk_level_{risk_level}_in_predictor.pkl",
)
assert os.path.exists(
file_path
), f"missing experiment with num_samples == {num_samples} and risk_level == {risk_level}"
with open(file_path, "rb") as f:
predictor_stats_dict[risk_level][num_samples] = pickle.load(f)
predictor_num_episodes = _get_num_episodes(
predictor_stats_dict[risk_level][num_samples]
)
predictor_interaction_cost_list = [
predictor_stats_dict[risk_level][num_samples][idx][
"interaction_cost_ground_truth"
]
for idx in range(predictor_num_episodes)
]
predictor_interaction_cost_mean_list.append(
np.mean(predictor_interaction_cost_list)
)
predictor_interaction_cost_sem_list.append(
st.sem(predictor_interaction_cost_list)
)
predictor_tracking_cost_list = [
predictor_stats_dict[risk_level][num_samples][idx]["tracking_cost"]
for idx in range(predictor_num_episodes)
]
predictor_tracking_cost_mean_list.append(
np.mean(predictor_tracking_cost_list)
)
predictor_tracking_cost_sem_list.append(
st.sem(predictor_tracking_cost_list)
)
if not "with_replanning" in stats_dir:
predictor_interaction_risk_list = [
predictor_stats_dict[risk_level][num_samples][idx][
"interaction_risk"
]
for idx in range(predictor_num_episodes)
]
predictor_interaction_risk_mean_list.append(
np.mean(predictor_interaction_risk_list)
)
predictor_interaction_risk_sem_list.append(
st.sem(predictor_interaction_risk_list)
)
predictor_total_objective_list = [
interaction_risk + tracking_cost
for (interaction_risk, tracking_cost) in zip(
predictor_interaction_risk_list, predictor_tracking_cost_list
)
]
predictor_total_objective_mean_list.append(
np.mean(predictor_total_objective_list)
)
predictor_total_objective_sem_list.append(
st.sem(predictor_total_objective_list)
)
file_path = os.path.join(
stats_dir,
f"{scene_type}_{num_samples}_samples_risk_level_{risk_level}_in_planner.pkl",
)
assert os.path.exists(
file_path
), f"missing experiment with num_samples == {num_samples} and risk_level == {risk_level}"
with open(file_path, "rb") as f:
planner_stats_dict[risk_level][num_samples] = pickle.load(f)
planner_num_episodes = _get_num_episodes(
planner_stats_dict[risk_level][num_samples]
)
planner_interaction_cost_list = [
planner_stats_dict[risk_level][num_samples][idx][
"interaction_cost_ground_truth"
]
for idx in range(planner_num_episodes)
]
planner_interaction_cost_mean_list.append(
np.mean(planner_interaction_cost_list)
)
planner_interaction_cost_sem_list.append(
st.sem(planner_interaction_cost_list)
)
planner_tracking_cost_list = [
planner_stats_dict[risk_level][num_samples][idx]["tracking_cost"]
for idx in range(planner_num_episodes)
]
planner_tracking_cost_mean_list.append(np.mean(planner_tracking_cost_list))
planner_tracking_cost_sem_list.append(st.sem(planner_tracking_cost_list))
if not "with_replanning" in stats_dir:
planner_interaction_risk_list = [
planner_stats_dict[risk_level][num_samples][idx]["interaction_risk"]
for idx in range(planner_num_episodes)
]
planner_interaction_risk_mean_list.append(
np.mean(planner_interaction_risk_list)
)
planner_interaction_risk_sem_list.append(
st.sem(planner_interaction_risk_list)
)
planner_total_objective_list = [
interaction_risk + tracking_cost
for (interaction_risk, tracking_cost) in zip(
planner_interaction_risk_list, planner_tracking_cost_list
)
]
planner_total_objective_mean_list.append(
np.mean(planner_total_objective_list)
)
planner_total_objective_sem_list.append(
st.sem(planner_total_objective_list)
)
(
predictor_interaction_cost_confint_lower,
predictor_interaction_cost_confint_upper,
) = st.norm.interval(
alpha=alpha_for_confint,
loc=predictor_interaction_cost_mean_list,
scale=predictor_interaction_cost_sem_list,
)
(
predictor_tracking_cost_confint_lower,
predictor_tracking_cost_confint_upper,
) = st.norm.interval(
alpha=alpha_for_confint,
loc=predictor_tracking_cost_mean_list,
scale=predictor_tracking_cost_sem_list,
)
if not "with_replanning" in stats_dir:
(
predictor_interaction_risk_confint_lower,
predictor_interaction_risk_confint_upper,
) = st.norm.interval(
alpha=alpha_for_confint,
loc=predictor_interaction_risk_mean_list,
scale=predictor_interaction_risk_sem_list,
)
(
predictor_total_objective_confint_lower,
predictor_total_objective_confint_upper,
) = st.norm.interval(
alpha=alpha_for_confint,
loc=predictor_total_objective_mean_list,
scale=predictor_total_objective_sem_list,
)
(
planner_interaction_cost_confint_lower,
planner_interaction_cost_confint_upper,
) = st.norm.interval(
alpha=alpha_for_confint,
loc=planner_interaction_cost_mean_list,
scale=planner_interaction_cost_sem_list,
)
(
planner_tracking_cost_confint_lower,
planner_tracking_cost_confint_upper,
) = st.norm.interval(
alpha=alpha_for_confint,
loc=planner_tracking_cost_mean_list,
scale=planner_tracking_cost_sem_list,
)
if not "with_replanning" in stats_dir:
(
planner_interaction_risk_confint_lower,
planner_interaction_risk_confint_upper,
) = st.norm.interval(
alpha=alpha_for_confint,
loc=planner_interaction_risk_mean_list,
scale=planner_interaction_risk_sem_list,
)
(
planner_total_objective_confint_lower,
planner_total_objective_confint_upper,
) = st.norm.interval(
alpha=alpha_for_confint,
loc=planner_total_objective_mean_list,
scale=planner_total_objective_sem_list,
)
ax[0][plot_idx].plot(
num_prediction_samples_list,
planner_interaction_cost_mean_list,
color="skyblue",
linewidth=2.0,
label="risk in planner",
)
ax[0][plot_idx].fill_between(
num_prediction_samples_list,
planner_interaction_cost_confint_upper,
planner_interaction_cost_confint_lower,
color="skyblue",
alpha=0.3,
)
ax[0][plot_idx].plot(
num_prediction_samples_list,
predictor_interaction_cost_mean_list,
color="orange",
linewidth=2.0,
label="risk in predictor",
)
ax[0][plot_idx].fill_between(
num_prediction_samples_list,
predictor_interaction_cost_confint_upper,
predictor_interaction_cost_confint_lower,
color="orange",
alpha=0.3,
)
ax[0][plot_idx].set_xlabel("Number of Prediction Samples")
ax[0][plot_idx].set_ylabel("Ground-Truth Collision Cost")
ax[0][plot_idx].set_title(f"Risk-Sensitivity Level: {risk_level}")
ax[0][plot_idx].legend(loc="upper right")
ax[0][plot_idx].set_xscale("log")
ax[1][plot_idx].plot(
num_prediction_samples_list,
planner_tracking_cost_mean_list,
color="skyblue",
linewidth=2.0,
label="risk in planner",
)
ax[1][plot_idx].fill_between(
num_prediction_samples_list,
planner_tracking_cost_confint_upper,
planner_tracking_cost_confint_lower,
color="skyblue",
alpha=0.3,
)
ax[1][plot_idx].plot(
num_prediction_samples_list,
predictor_tracking_cost_mean_list,
color="orange",
linewidth=2.0,
label="risk in predictor",
)
ax[1][plot_idx].fill_between(
num_prediction_samples_list,
predictor_tracking_cost_confint_upper,
predictor_tracking_cost_confint_lower,
color="orange",
alpha=0.3,
)
ax[1][plot_idx].set_xlabel("Number of Prediction Samples")
ax[1][plot_idx].set_ylabel("Trajectory Tracking Cost")
# ax[1][plot_idx].set_title(f"Risk-Sensitivity Level: {risk_level}")
ax[1][plot_idx].legend(loc="lower right")
ax[1][plot_idx].set_xscale("log")
if not "with_replanning" in stats_dir:
ax[2][plot_idx].plot(
num_prediction_samples_list,
planner_interaction_risk_mean_list,
color="skyblue",
linewidth=2.0,
label="risk in planner",
)
ax[2][plot_idx].fill_between(
num_prediction_samples_list,
planner_interaction_risk_confint_upper,
planner_interaction_risk_confint_lower,
color="skyblue",
alpha=0.3,
)
ax[2][plot_idx].plot(
num_prediction_samples_list,
predictor_interaction_risk_mean_list,
color="orange",
linewidth=2.0,
label="risk in predictor",
)
ax[2][plot_idx].fill_between(
num_prediction_samples_list,
predictor_interaction_risk_confint_upper,
predictor_interaction_risk_confint_lower,
color="orange",
alpha=0.3,
)
ax[2][plot_idx].set_xlabel("Number of Prediction Samples")
ax[2][plot_idx].set_ylabel("Collision Risk")
# ax[2][plot_idx].set_title(f"Risk-Sensitivity Level: {risk_level}")
ax[2][plot_idx].legend(loc="upper right")
ax[2][plot_idx].set_xscale("log")
ax[3][plot_idx].plot(
num_prediction_samples_list,
planner_total_objective_mean_list,
color="skyblue",
linewidth=2.0,
label="risk in planner",
)
ax[3][plot_idx].fill_between(
num_prediction_samples_list,
planner_total_objective_confint_upper,
planner_total_objective_confint_lower,
color="skyblue",
alpha=0.3,
)
ax[3][plot_idx].plot(
num_prediction_samples_list,
predictor_total_objective_mean_list,
color="orange",
linewidth=2.0,
label="risk in predictor",
)
ax[3][plot_idx].fill_between(
num_prediction_samples_list,
predictor_total_objective_confint_upper,
predictor_total_objective_confint_lower,
color="orange",
alpha=0.3,
)
ax[3][plot_idx].set_xlabel("Number of Prediction Samples")
ax[3][plot_idx].set_ylabel("Planner's Total Objective")
# ax[3][plot_idx].set_title(f"Risk-Sensitivity Level: {risk_level}")
ax[3][plot_idx].legend(loc="upper right")
ax[3][plot_idx].set_xscale("log")
    plt.show()


def _get_num_episodes(stats_dict: dict) -> int:
    """Return the episode count, assuming episode results are stored under
    consecutive integer keys starting at 0 (other keys are ignored)."""
    return max(filter(lambda key: isinstance(key, int), stats_dict)) + 1


if __name__ == "__main__":
parser = argparse.ArgumentParser(
        description="visualize evaluation results of evaluate_prediction_planning_stack(_with_replanning).py"
)
parser.add_argument(
"--load_from",
type=str,
required=True,
help="WandB ID for specification of trained predictor",
)
parser.add_argument(
"--seed",
type=int,
required=False,
default=0,
)
parser.add_argument(
"--scene_type",
type=str,
choices=["safer_fast", "safer_slow"],
required=True,
)
    parser.add_argument(
        "--with_replanning",
        action="store_true",
        help="Plot results evaluated with re-planning instead of open-loop results",
    )
parser.add_argument(
"--risk_level",
type=float,
nargs="+",
help="Risk-sensitivity level(s) to test",
default=[0.95, 1.0],
)
parser.add_argument(
"--num_samples",
type=int,
nargs="+",
help="Number(s) of prediction samples to test",
default=[1, 4, 16, 64, 256, 1024],
)
parser.add_argument(
"--force_config",
action="store_true",
help="""Use this flag to force the use of the local config file
when loading a model from a checkpoint. Otherwise the checkpoint config file is used.
In any case the parameters can be overwritten with an argparse argument.""",
)
args = parser.parse_args()
dir_name = (
"planner_eval_with_replanning" if args.with_replanning else "planner_eval"
)
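    # The evaluation scripts are expected to have written their pickled stats
    # to logs/<dir_name>/run-<wandb ID>_<seed>/ next to this script.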
stats_dir = os.path.join(
os.path.dirname(os.path.realpath(__file__)),
"logs",
dir_name,
f"run-{args.load_from}_{args.seed}",
)
postfix_string = "_with_replanning" if args.with_replanning else ""
assert os.path.exists(
stats_dir
), f"{stats_dir} does not exist. Did you run 'evaluate_prediction_planning_stack{postfix_string}.py --load_from {args.load_from} --seed {args.seed}' ?"
plot_main(stats_dir, args.scene_type, args.risk_level, args.num_samples)