jmercat's picture
Removed history to avoid any unverified information being released
5769ee4
import copy
from dataclasses import dataclass
from mmcv import Config
import matplotlib.pyplot as plt
import numpy as np
from pydantic import NoneBytes
import pytorch_lightning as pl
import torch
import wandb
from risk_biased.scene_dataset.loaders import SceneDataLoaders
from risk_biased.scene_dataset.scene import RandomScene, RandomSceneParams
from risk_biased.scene_dataset.scene_plotter import ScenePlotter
from risk_biased.utils.cost import (
DistanceCostNumpy,
DistanceCostParams,
TTCCostNumpy,
TTCCostParams,
)
from risk_biased.utils.risk import get_risk_level_sampler
class SwitchTrainingModeCallback(pl.Callback):
    """Switch the biased_latent_cvae_model from CVAE training to bias training.

    Training starts with only the first optimizer of the trainer active. Once
    the configured epoch is reached, the module is put in "bias" mode and the
    trainer is handed the second optimizer instead, so the CVAE is not trained
    anymore after that point.

    Args:
        switch_at_epoch: The epoch index at which the switch is made.
    """

    def __init__(self, switch_at_epoch: int) -> None:
        super().__init__()
        # The full optimizer list is only known once training starts, so the
        # flag below guards against using it before `on_train_start` has run.
        self._train_has_started = False
        self._switch_at_epoch = switch_at_epoch

    def on_train_start(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        """Remember every optimizer and leave only the first one active."""
        all_optimizers = trainer.optimizers
        self._optimizers = all_optimizers
        trainer.optimizers = [all_optimizers[0]]
        self._train_has_started = True

    # NOTE(review): `on_epoch_start` is a version-sensitive Lightning hook
    # (believed deprecated in 1.6 and removed in 1.8) -- confirm that the
    # pinned pytorch_lightning version still calls it.
    def on_epoch_start(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        """Activate bias training when the switch epoch is reached.

        Nothing happens on other epochs, or if training has not started yet.
        """
        if trainer.current_epoch != self._switch_at_epoch:
            return
        if not self._train_has_started:
            return

        print("Switching to bias training.")
        pl_module.set_training_mode("bias")
        trainer.optimizers = [self._optimizers[1]]
def get_fast_slow_scenes(params: RandomSceneParams, n_samples: int):
    """Build the two reference scenes used by the drawing callbacks.

    Both scenes share a deep copy of `params` whose batch size is overridden
    with `n_samples`, and in both every pedestrian gets the same position and
    the same heading angle. The scenes differ only in the second position
    coordinate (`percent_top`): 1.1 for the scene in which slow pedestrians
    are safer, 0.6 for the scene in which fast pedestrians are safer.

    Args:
        params: dataclass containing the necessary parameters for a RandomScene object
            (left unmodified, a copy is used)
        n_samples: number of samples to draw in each scene

    Returns:
        The tuple (scene_safe_fast, scene_safe_slow).
    """
    scene_params = copy.deepcopy(params)
    scene_params.batch_size = n_samples

    # Shared by both scenes.
    # presumably fractions of the road length/width -- confirm in RandomScene
    percent_right = 0.8
    heading = 5 * np.pi / 4

    def make_scene(percent_top):
        """Create a numpy scene with all pedestrians in one identical state."""
        scene = RandomScene(
            scene_params,
            is_torch=False,
        )
        positions = np.array([[[percent_right, percent_top]]] * n_samples)
        angles = np.array([[heading]] * n_samples)
        scene.set_pedestrians_states(positions, angles)
        return scene

    # The slow scene is created first, the fast one second.
    scene_safe_slow = make_scene(1.1)
    scene_safe_fast = make_scene(0.6)
    return scene_safe_fast, scene_safe_slow
@dataclass
class DrawCallbackParams:
    """Parameters shared by the histogram and trajectory drawing callbacks.

    Args:
        scene_params: dataclass parameters for the RandomScene
        dist_cost_params: dataclass parameters for the DistanceCost
        ttc_cost_params: dataclass parameters for the TTCCost
        plot_interval_epoch: number of epochs between each plot drawing
        histogram_interval_epoch: number of epochs between each histogram drawing
        num_steps: number of time steps as defined in the config
        num_steps_future: number of time steps in the future as defined in the config
        risk_distribution: dict object describing a risk distribution
        dt: time step size as defined in the config
    """

    scene_params: RandomSceneParams
    dist_cost_params: DistanceCostParams
    ttc_cost_params: TTCCostParams
    plot_interval_epoch: int
    histogram_interval_epoch: int
    num_steps: int
    num_steps_future: int
    risk_distribution: dict
    dt: float

    @staticmethod
    def from_config(cfg: Config):
        """Create the parameters from a config object.

        The three nested parameter dataclasses are built with their own
        `from_config`; every other field is read from the attribute of `cfg`
        that carries the same name.

        Args:
            cfg: configuration object exposing the required attributes

        Returns:
            A populated DrawCallbackParams instance.
        """
        scene_params = RandomSceneParams.from_config(cfg)
        dist_cost_params = DistanceCostParams.from_config(cfg)
        ttc_cost_params = TTCCostParams.from_config(cfg)

        # Fields copied verbatim from the config, in declaration order.
        copied_fields = (
            "plot_interval_epoch",
            "histogram_interval_epoch",
            "num_steps",
            "num_steps_future",
            "risk_distribution",
            "dt",
        )
        copied = {field: getattr(cfg, field) for field in copied_fields}

        return DrawCallbackParams(
            scene_params=scene_params,
            dist_cost_params=dist_cost_params,
            ttc_cost_params=ttc_cost_params,
            **copied,
        )
class HistogramCallback(pl.Callback):
    """Log WandB histograms comparing data, predictions and risk-biased predictions.

    For two fixed scenes (one safer for fast pedestrians, one safer for slow
    pedestrians) the travel distance, the distance cost and the TTC cost are
    computed for three series: the scene data, the prediction obtained with
    `risk_level=None`, and the prediction obtained with the highest risk level
    given by the risk sampler.

    Args:
        params: dataclass defining the necessary parameters
        n_samples: Number of samples to use for the histogram plot
    """

    # Middle part of every column/chart name, in table column order.
    _SERIES_LABELS = ("data", "prediction", "riskier")

    def __init__(
        self,
        params: DrawCallbackParams,
        n_samples=1000,
    ):
        super().__init__()
        fast_scene, slow_scene = get_fast_slow_scenes(params.scene_params, n_samples)
        self.scene_safe_fast = fast_scene
        self.scene_safe_slow = slow_scene
        self.num_steps = params.num_steps
        self.n_scenes = n_samples
        self.sample_times = params.scene_params.sample_times
        self.dist_cost_func = DistanceCostNumpy(params.dist_cost_params)
        self.ttc_cost_func = TTCCostNumpy(params.ttc_cost_params)
        self.histogram_interval_epoch = params.histogram_interval_epoch
        # The ego reference trajectory of the "safer fast" scene is reused for
        # both scenes.
        self.ego_traj = fast_scene.get_ego_ref_trajectory(self.sample_times)
        self._risk_sampler = get_risk_level_sampler(params.risk_distribution)

    @staticmethod
    def _to_numpy(tensor):
        """Move a prediction tensor to the CPU and convert it to a numpy array."""
        return tensor.cpu().detach().numpy()

    @staticmethod
    def _travel_distance(trajectories):
        """Euclidean distance between the first and last point of each trajectory.

        Time is indexed on the second-to-last axis, coordinates on the last one.
        """
        displacement = trajectories[..., -1, :] - trajectories[..., 0, :]
        return np.sqrt(np.square(displacement).sum(-1))

    def _finite_difference_velocities(self, trajectories):
        """Velocities from consecutive position differences over `sample_times`.

        The first velocity is duplicated so that the result has as many time
        steps as `trajectories`.
        """
        times = np.array(self.sample_times)
        time_deltas = (times[1:] - times[:-1])[None, None, :, None]
        velocities = (trajectories[..., 1:, :] - trajectories[..., :-1, :]) / time_deltas
        return np.concatenate((velocities[..., 0:1, :], velocities), -2)

    def _prediction_metrics(self, scene: RandomScene, trajectories):
        """Travel distance, distance cost and TTC cost of a predicted trajectory.

        Args:
            scene: scene providing the ego reference velocity
            trajectories: input steps followed by the predicted future steps

        Returns:
            The tuple (travel_distance, distance_cost, ttc_cost).
        """
        ego_future = self.ego_traj[..., self.num_steps :, :]
        future_trajectories = trajectories[..., self.num_steps :, :]

        travel = self._travel_distance(trajectories)
        dist_cost, _ = self.dist_cost_func(ego_future, future_trajectories)
        velocities = self._finite_difference_velocities(trajectories)
        ttc_cost, _ = self.ttc_cost_func(
            ego_future,
            future_trajectories,
            scene.get_ego_ref_velocity(),
            velocities[..., self.num_steps :, :],
        )
        return travel, dist_cost, ttc_cost

    def _log_scene(self, pl_module: pl.LightningModule, scene: RandomScene, name: str):
        """Log nine WandB histograms for the given scene.

        Three quantities (travel distance, distance cost, TTC cost) are each
        logged for the data, for the prediction with `risk_level=None` and for
        the prediction at the highest risk level of the risk sampler.

        Args:
            pl_module: LightningModule object
            scene: RandomScene object
            name: name of the given scene
        """
        first_future = self.num_steps
        ped_trajs = scene.get_pedestrians_trajectories()
        device = pl_module.device
        n_agents = ped_trajs.shape[1]
        input_traj = ped_trajs[..., :first_future, :]

        # Model inputs: normalized pedestrian history and the ego trajectory
        # split into its history and future parts.
        history_tensor = (
            torch.from_numpy(input_traj.astype("float32")).contiguous().to(device)
        )
        normalized_input, offset = SceneDataLoaders.normalize_trajectory(history_tensor)
        mask_input = torch.ones_like(normalized_input[..., 0])

        ego_history = torch.from_numpy(
            self.ego_traj[..., :first_future, :].astype("float32")
        )
        ego_history = ego_history.expand_as(normalized_input).contiguous().to(device)
        ego_future = torch.from_numpy(
            self.ego_traj[..., first_future:, :].astype("float32")
        )
        ego_future = (
            ego_future.expand(normalized_input.shape[0], n_agents, -1, -1)
            .contiguous()
            .to(device)
        )

        # Zero-sized map tensors: no map information is given to the model.
        batch_size = ego_history.shape[0]
        empty_map = torch.empty(batch_size, 0, 0, 2, device=mask_input.device)
        empty_map_mask = torch.empty(batch_size, 0, 0, device=mask_input.device)
        batch = (
            normalized_input,
            mask_input,
            empty_map,
            empty_map_mask,
            offset,
            ego_history,
            ego_future,
        )

        # The risk-biased prediction is computed first, then the unbiased one.
        highest_risk = (
            self._risk_sampler.get_highest_risk(
                batch_size=self.n_scenes, device=device
            )
            .unsqueeze(1)
            .repeat(1, n_agents)
        )
        pred_riskier = self._to_numpy(
            pl_module.predict_step(batch, 0, risk_level=highest_risk)
        )
        pred = self._to_numpy(pl_module.predict_step(batch, 0, risk_level=None))

        ped_trajs_pred = np.concatenate((input_traj, pred), axis=-2)
        ped_trajs_pred_riskier = np.concatenate((input_traj, pred_riskier), axis=-2)

        # Metrics of the scene data; velocities are given by the scene itself.
        ego_future_np = self.ego_traj[..., first_future:, :]
        travel_data = self._travel_distance(ped_trajs)
        dist_cost_data, _ = self.dist_cost_func(
            ego_future_np, ped_trajs[..., first_future:, :]
        )
        ttc_cost_data, _ = self.ttc_cost_func(
            ego_future_np,
            ped_trajs[..., first_future:, :],
            scene.get_ego_ref_velocity(),
            scene.get_pedestrians_velocities(),
        )

        # Metrics of both predictions; velocities come from finite differences.
        travel_pred, dist_cost_pred, ttc_cost_pred = self._prediction_metrics(
            scene, ped_trajs_pred
        )
        travel_risk, dist_cost_risk, ttc_cost_risk = self._prediction_metrics(
            scene, ped_trajs_pred_riskier
        )

        # (title, vega spec, (data, prediction, riskier)) for each histogram group.
        histogram_groups = (
            (
                "Travel distance",
                "jmercat/histogram_01_bins",
                (travel_data, travel_pred, travel_risk),
            ),
            (
                "Distance cost",
                "jmercat/histogram_0025_bins",
                (dist_cost_data, dist_cost_pred, dist_cost_risk),
            ),
            (
                "TTC cost",
                "jmercat/histogram_005_bins",
                (ttc_cost_data, ttc_cost_pred, ttc_cost_risk),
            ),
        )

        # One three-column table per group, one row per flattened sample.
        tables = []
        for title, vega_spec, series in histogram_groups:
            columns = [" ".join((title, label, name)) for label in self._SERIES_LABELS]
            rows = [list(row) for row in zip(*(values.flatten() for values in series))]
            tables.append((vega_spec, columns, wandb.Table(data=rows, columns=columns)))

        # One chart per column, keyed by the column name.
        charts = {}
        for vega_spec, columns, table in tables:
            for column in columns:
                charts[column] = wandb.plot_table(
                    vega_spec_name=vega_spec,
                    data_table=table,
                    fields={"value": column, "title": column},
                )
        wandb.log(charts)

    def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        """Log the histograms of both scenes at the end of some validations.

        Logging happens only for the last epoch of each block of
        `histogram_interval_epoch` epochs, first for the safer fast scene and
        then for the safer slow scene.
        """
        interval = self.histogram_interval_epoch
        if trainer.current_epoch % interval != interval - 1:
            return
        scenes = (
            (self.scene_safe_fast, "Safer fast"),
            (self.scene_safe_slow, "Safer slow"),
        )
        for scene, scene_name in scenes:
            self._log_scene(pl_module, scene, name=scene_name)
class PlotTrajCallback(pl.Callback):
    """Plot trajectory samples for two scenes:
        One that is safer for the slow pedestrians
        One that is safer for the fast pedestrians
    Samples of ground truth, prediction, and biased predictions are superposed.
    Last positions are marked to visualize the clusters.

    Args:
        params: dataclass containing the necessary parameters for the callback
            (scene parameters, number of input steps, plot interval and risk
            distribution are read from it)
        n_samples: number of sample trajectories to draw
    """

    def __init__(
        self,
        params: DrawCallbackParams,
        n_samples: int = 1,
    ):
        super().__init__()
        self.n_samples = n_samples
        self.num_steps = params.num_steps
        self.dt = params.scene_params.dt
        self.scene_params = params.scene_params
        self.plot_interval_epoch = params.plot_interval_epoch
        self.scene_safe_fast, self.scene_safe_slow = get_fast_slow_scenes(
            params.scene_params, n_samples
        )
        # The ego reference trajectory of the "safer fast" scene is reused for
        # both scenes.
        self.ego_traj = self.scene_safe_fast.get_ego_ref_trajectory(
            params.scene_params.sample_times
        )
        self._risk_sampler = get_risk_level_sampler(params.risk_distribution)

    @staticmethod
    def _trajectory_alpha(n_samples: int) -> float:
        """Return the opacity to use when `n_samples` trajectories are superposed.

        The opacity decreases as 0.5 / log(n_samples). Since log(1) is 0, that
        formula divides by zero for a single sample (the default), so a fully
        opaque line is used whenever there is at most one sample.
        """
        if n_samples <= 1:
            return 1.0
        return 0.5 / float(np.log(n_samples))

    def _log_scene(self, epoch: int, pl_module, scene: RandomScene, name: str) -> None:
        """Add drawing of samples of prediction, biased prediction and ground truth in the scene.

        The figure is logged to WandB under the key "Road scene <name>"
        together with the epoch, then closed.

        Args:
            epoch: current epoch calling the log
            pl_module: pytorch lightning module being trained
            scene: scene to draw
            name: name of the scene
        """
        ped_trajs = scene.get_pedestrians_trajectories()
        device = pl_module.device
        n_agents = ped_trajs.shape[1]
        input_traj = ped_trajs[..., : self.num_steps, :]

        # Model inputs: normalized pedestrian history and the ego trajectory
        # split into its history and future parts.
        normalized_input, offset = SceneDataLoaders.normalize_trajectory(
            torch.from_numpy(input_traj.astype("float32")).contiguous().to(device)
        )
        mask_input = torch.ones_like(normalized_input[..., 0])
        ego_history = (
            torch.from_numpy(self.ego_traj[..., : self.num_steps, :].astype("float32"))
            .expand_as(normalized_input)
            .contiguous()
            .to(device)
        )
        ego_future = (
            torch.from_numpy(self.ego_traj[..., self.num_steps :, :].astype("float32"))
            .expand(normalized_input.shape[0], n_agents, -1, -1)
            .contiguous()
            .to(device)
        )
        # Zero-sized map tensors: no map information is given to the model.
        # (Named so that the builtin `map` is not shadowed.)
        empty_map = torch.empty(
            ego_history.shape[0], 0, 0, 2, device=mask_input.device
        )
        empty_map_mask = torch.empty(
            ego_history.shape[0], 0, 0, device=mask_input.device
        )
        batch = (
            normalized_input,
            mask_input,
            empty_map,
            empty_map_mask,
            offset,
            ego_history,
            ego_future,
        )

        # Risk-biased prediction at the highest risk level of the sampler,
        # followed by the prediction without any risk level.
        highest_risk = (
            self._risk_sampler.get_highest_risk(
                batch_size=self.n_samples, device=device
            )
            .unsqueeze(1)
            .repeat(1, n_agents)
        )
        pred_riskier = (
            pl_module.predict_step(batch, 0, risk_level=highest_risk)
            .cpu()
            .detach()
            .numpy()
        )
        pred = (
            pl_module.predict_step(batch, 0, risk_level=None).cpu().detach().numpy()
        )

        fig, ax = plt.subplots()
        plotter = ScenePlotter(scene, ax=ax)
        fig.set_size_inches(h=scene.road_width / 3 + 1, w=scene.road_length / 3)
        # The scene is drawn at the time of the last input step.
        time = self.dt * self.num_steps
        plotter.draw_scene(0, time=time)

        alpha = self._trajectory_alpha(self.n_samples)
        plotter.draw_all_trajectories(
            ped_trajs[..., self.num_steps :, :],
            color="g",
            alpha=alpha,
            label="Future ground truth",
        )
        plotter.draw_all_trajectories(
            input_traj, color="b", alpha=alpha, label="Past input"
        )
        plotter.draw_all_trajectories(
            pred, color="orange", alpha=alpha, label="Prediction"
        )
        plotter.draw_all_trajectories(
            pred_riskier, color="r", alpha=alpha, label="Prediction risk-seeking"
        )
        plotter.draw_legend()
        plt.tight_layout()
        wandb.log({"Road scene " + name: wandb.Image(fig), "epoch": epoch})
        # Close this specific figure so that figures do not pile up over epochs.
        plt.close(fig)

    def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        """After a validation at the end of every plot_interval_epoch,
        log the prediction samples for two scenes: the safer fast scene and the safer slow scene.

        Both scenes are re-created before each drawing.
        """
        if (
            trainer.current_epoch % self.plot_interval_epoch
            == self.plot_interval_epoch - 1
        ):
            self.scene_safe_fast, self.scene_safe_slow = get_fast_slow_scenes(
                self.scene_params, self.n_samples
            )
            self._log_scene(
                trainer.current_epoch, pl_module, self.scene_safe_slow, "Safer slow"
            )
            self._log_scene(
                trainer.current_epoch, pl_module, self.scene_safe_fast, "Safer fast"
            )
# TODO: make the same kind of logs for the Waymo dataset