# jmercat — commit 6865b66 (verified): "trust_remote_code=true when loading private dataset"
from datasets import load_dataset, Dataset
import fire
from functools import partial, update_wrapper
import numpy
import os
from typing import Dict, Iterable, Tuple
import sys
import time
import torch
import gradio as gr
from huggingface_hub import hf_hub_download
from mmcv import Config
import plotly.graph_objects as go
from torch.utils.data.dataloader import DataLoader
from risk_biased.utils.load_model import get_predictor
from risk_biased.utils.torch_utils import load_weights
from risk_biased.utils.waymo_dataloader import WaymoDataloaders
from risk_biased.predictors.biased_predictor import (
LitTrajectoryPredictor,
)
def to_numpy(**kwargs):
    """Convert every keyword-argument tensor to a numpy array.

    Each tensor is detached from the autograd graph and copied to the CPU first.

    Returns:
        A dict mapping each keyword name to its numpy array, in argument order.
    """
    return {name: tensor.detach().cpu().numpy() for name, tensor in kwargs.items()}
def get_scatter_data(x, mask_x, name, **kwargs):
    """Build one plotly scatter trace per row (object) of ``x``.

    Args:
        x: Array indexed as ``x[row, point, coord]``; coords 0 and 1 are plotted
            as the horizontal and vertical positions.
        mask_x: Boolean array indexed as ``mask_x[row, point]``; only points where
            it is True are kept.
        name: Legend name shared by all traces. Only the first trace shows a
            legend entry, so the group appears once in the legend.
        **kwargs: Extra ``go.Scatter`` properties (mode, line, marker, ...).

    Returns:
        A list of ``go.Scatter`` traces, one per row of ``x``.
    """
    traces = []
    for row in range(x.shape[0]):
        keep = mask_x[row]
        trace = go.Scatter(
            x=x[row, keep, 0],
            y=x[row, keep, 1],
            showlegend=row == 0,
            name=name,
            **kwargs,
        )
        traces.append(trace)
    return traces
def configuration_paths() -> Iterable[os.PathLike]:
    """Return the paths of the learning and Waymo config files.

    The paths are built relative to this file's resolved location, as
    ``<this dir>/../../risk_biased/config/<name>``. The ``..`` segments are not
    normalised.
    """
    here = os.path.dirname(os.path.realpath(__file__))
    relative_config_dir = "../../risk_biased/config"
    paths = []
    for file_name in ("learning_config.py", "waymo_config.py"):
        paths.append(os.path.join(here, relative_config_dir, file_name))
    return paths
def _to_tensor(value, dtype, device) -> torch.Tensor:
    """Convert one raw dataset field to a tensor with numpy ``dtype`` on ``device``."""
    return torch.from_numpy(numpy.array(value).astype(dtype)).to(device)


def load_item(index: int, dataset: "Dataset", device: str = "cpu") -> Tuple:
    """Load one dataset row and convert its fields to tensors on ``device``.

    Args:
        index: Row index in ``dataset``.
        dataset: Indexable collection whose rows are mappings containing the keys
            "x", "mask_x", "y", "mask_y", "mask_loss", "map_data", "mask_map",
            "offset", "x_ego" and "y_ego".
        device: Device the tensors are moved to.

    Returns:
        ``((x, mask_x, map_data, mask_map, offset, x_ego, y_ego), y, mask_y, mask_loss)``.
        Trajectory/map/offset fields are float32; ``mask_*`` fields are bool.
    """
    # Fetch the row once. Each ``dataset[index]`` lookup on a ``datasets.Dataset``
    # materialises the full row, and this used to happen once per field.
    item = dataset[index]
    x = _to_tensor(item["x"], numpy.float32, device)
    mask_x = _to_tensor(item["mask_x"], numpy.bool_, device)
    y = _to_tensor(item["y"], numpy.float32, device)
    mask_y = _to_tensor(item["mask_y"], numpy.bool_, device)
    mask_loss = _to_tensor(item["mask_loss"], numpy.bool_, device)
    map_data = _to_tensor(item["map_data"], numpy.float32, device)
    mask_map = _to_tensor(item["mask_map"], numpy.bool_, device)
    offset = _to_tensor(item["offset"], numpy.float32, device)
    x_ego = _to_tensor(item["x_ego"], numpy.float32, device)
    y_ego = _to_tensor(item["y_ego"], numpy.float32, device)
    return (x, mask_x, map_data, mask_map, offset, x_ego, y_ego), y, mask_y, mask_loss
def _static_traces(x, mask_x, y, mask_y, pred, mask_pred, map_data, mask_map, n_samples, marker_size):
    """Return the non-animated traces; plotly draws them in list order, so the map comes first."""
    data_map = get_scatter_data(
        map_data,
        mask_map,
        mode="lines",
        line=dict(width=15, color="gray"),
        opacity=0.3,
        name="Centerline",
    )
    data_x = get_scatter_data(
        x,
        mask_x,
        mode="lines",
        line=dict(width=2, color="black"),
        name="Past",
    )
    data_y = get_scatter_data(
        y,
        mask_y,
        mode="lines",
        line=dict(width=2, color="green"),
        name="Ground truth",
    )
    # Last observed position of row 0 (ego) and row 1 (the predicted agent).
    ego_present = get_scatter_data(
        x=x[0:1, -1:],
        mask_x=mask_x[0:1, -1:],
        mode="markers",
        marker=dict(color="blue", size=marker_size, opacity=0.5),
        name="Ego",
    )
    agent_present = get_scatter_data(
        x=x[1:2, -1:],
        mask_x=mask_x[1:2, -1:],
        mode="markers",
        marker=dict(color="green", size=marker_size, opacity=0.5),
        name="Agent",
    )
    data_pred = []
    forecasts_end = []
    for i in range(n_samples):
        data_pred += get_scatter_data(
            pred[:, i],
            mask_pred,
            mode="lines",
            line=dict(width=2, color="red"),
            name="Forecast",
        )
        forecasts_end += get_scatter_data(
            pred[:, i, -1:],
            mask_pred[:, -1:],
            mode="markers",
            marker=dict(color="red", size=marker_size / 2, opacity=0.5, symbol="x"),
            name="Forecast end",
        )
    return data_map + data_x + data_y + data_pred + ego_present + agent_present + forecasts_end


def _marker_trace(xs, ys, color, marker_size, opacity=0.5):
    """Return a legend-less marker trace used as one slot of an animation frame."""
    return go.Scatter(
        x=xs,
        y=ys,
        mode="markers",
        opacity=opacity,
        marker=dict(color=color, size=marker_size),
        showlegend=False,
    )


def _frame_traces(x, mask_x, y, mask_y, pred, mask_pred, n_samples, marker_size):
    """Return one list of ``n_samples + 3`` traces per animation step (past steps, then future steps).

    Every frame uses the same slot layout, so a slot always keeps one color and
    no slot keeps stale data when the animation loops:
    0 = predicted agent ground truth (green), 1 = other agents (black),
    2 .. n_samples + 1 = forecast samples (red), n_samples + 2 = ego (blue).
    """
    frames = []
    for k in range(x.shape[1]):
        valid = mask_x[:, k]
        frames.append(
            [
                _marker_trace([], [], "green", marker_size),
                # In the past, every valid row (ego and agent included) is drawn in black.
                _marker_trace(x[valid, k, 0], x[valid, k, 1], "black", marker_size),
                *[_marker_trace([], [], "red", marker_size) for _ in range(n_samples)],
                _marker_trace(x[0:1, k, 0], x[0:1, k, 1], "blue", marker_size),
            ]
        )
    for k in range(y.shape[1]):
        agent_valid = mask_y[1:2, k]
        others_valid = mask_y[2:, k]
        pred_valid = mask_pred[:, k]
        frames.append(
            [
                _marker_trace(y[1:2][agent_valid, k, 0], y[1:2][agent_valid, k, 1], "green", marker_size),
                _marker_trace(y[2:][others_valid, k, 0], y[2:][others_valid, k, 1], "black", marker_size),
                *[
                    _marker_trace(pred[pred_valid, i, k, 0], pred[pred_valid, i, k, 1], "red", marker_size)
                    for i in range(n_samples)
                ],
                _marker_trace(y[0:1, k, 0], y[0:1, k, 1], "blue", marker_size),
            ]
        )
    return frames


def build_data(
    predictor: LitTrajectoryPredictor,
    dataset: Dataset,
    index: int,
    risk_level: float,
    n_samples: int,
) -> Dict[str, list]:
    """Run the predictor on one dataset item and build the figure's traces and frames.

    Args:
        predictor: Trajectory predictor used for ``predict_step``.
        dataset: Dataset the item is read from (see ``load_item``).
        index: Item index in ``dataset``.
        risk_level: Risk level forwarded to ``predict_step`` (may be None).
        n_samples: Number of forecast samples to draw; must be >= 1.

    Returns:
        ``{"data": [...static and placeholder traces...], "frames": [...go.Frame...]}``,
        ready to be passed as ``go.Figure(**result)``.

    Raises:
        ValueError: if ``n_samples < 1``.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")
    batch, y, mask_y, mask_loss = load_item(index, dataset, predictor.device)
    predictions = predictor.predict_step(
        batch=batch,
        risk_level=risk_level,
        n_samples=n_samples,
    )
    offset = batch[4]
    y = predictor._unnormalize_trajectory(y, offset)
    x = predictor._unnormalize_trajectory(batch[0], offset)
    numpy_data = to_numpy(
        predictions=predictions,
        y=y,
        mask_y=mask_y,
        x=x,
        mask_x=batch[1],
        map_data=batch[2],
        mask_map=batch[3],
        mask_pred=mask_loss,
    )
    # Only the first element along the leading axis is plotted.
    scene = {key: value[0] for key, value in numpy_data.items()}
    x, mask_x = scene["x"], scene["mask_x"]
    y, mask_y = scene["y"], scene["mask_y"]
    pred, mask_pred = scene["predictions"], scene["mask_pred"]
    marker_size = 12

    static_data = _static_traces(
        x, mask_x, y, mask_y, pred, mask_pred, scene["map_data"], scene["mask_map"], n_samples, marker_size
    )

    # The animated markers live in dedicated placeholder traces appended after the
    # static ones. A frame without explicit ``traces`` is applied by plotly to the
    # figure's first traces, which are the map centerlines here. So every frame
    # names its target traces.
    n_animated = n_samples + 3
    first_animated = len(static_data)
    animated_indices = list(range(first_animated, first_animated + n_animated))
    placeholders = [go.Scatter(x=[], y=[], mode="markers", showlegend=False) for _ in range(n_animated)]
    frames = [
        go.Frame(data=frame_data, traces=animated_indices)
        for frame_data in _frame_traces(x, mask_x, y, mask_y, pred, mask_pred, n_samples, marker_size)
    ]
    return {"frames": frames, "data": static_data + placeholders}
def prediction_plot(
    predictor: LitTrajectoryPredictor,
    dataset: Dataset,
    index: int,
    risk_level: float,
    n_samples: int = 1,
    use_biaser: bool = True,
) -> go.Figure:
    """Build the animated "Road Scene" figure for one dataset item.

    Args:
        predictor: Trajectory predictor passed to ``build_data``.
        dataset: Dataset the item is read from.
        index: Item index in ``dataset``.
        risk_level: Risk level; converted to float when ``use_biaser`` is True.
        n_samples: Number of forecast samples.
        use_biaser: When False, the risk level is replaced by None before prediction.

    Returns:
        A fixed-range 800x600 figure with Play/Pause animation buttons.
    """
    range_radius = 80
    effective_risk = float(risk_level) if use_biaser else None

    play_button = dict(
        label="Play",
        method="animate",
        args=[
            None,
            dict(
                frame=dict(duration=100, redraw=False),
                mode="immediate",
                fromcurrent=True,
            ),
        ],
    )
    pause_button = dict(
        label="Pause",
        method="animate",
        args=[
            [None],
            {
                "frame": {"duration": 0, "redraw": False},
                "mode": "immediate",
                "transition": {"duration": 0},
            },
        ],
    )
    # The x window is shifted forward: from -0.5 to +1.5 times the radius.
    x_axis = dict(range=[-0.5 * range_radius, 1.5 * range_radius], autorange=False, zeroline=False)
    y_axis = dict(range=[-range_radius, range_radius], autorange=False, zeroline=False)
    layout = go.Layout(
        xaxis=x_axis,
        yaxis=y_axis,
        title_text="Road Scene",
        hovermode="closest",
        width=800,
        height=600,
        updatemenus=[dict(type="buttons", buttons=[play_button, pause_button])],
    )

    figure_content = build_data(predictor, dataset, index, effective_risk, n_samples)
    figure = go.Figure(**figure_content, layout=layout)
    # NOTE(review): the traces built by build_data are cartesian go.Scatter, so this
    # geo setting appears to have no visible effect. Kept as-is.
    figure.update_geos(projection_type="equirectangular", visible=True, resolution=110)
    return figure
def get_figure(
    predictor: LitTrajectoryPredictor,
    dataset: Dataset,
    index: int,
    risk_level: float,
    n_samples: int,
) -> go.Figure:
    """Return the risk-biased prediction figure for ``index`` (biaser always enabled)."""
    figure = prediction_plot(
        predictor=predictor,
        dataset=dataset,
        index=index,
        risk_level=risk_level,
        n_samples=n_samples,
        use_biaser=True,
    )
    figure.update_layout()  # called without arguments, kept for parity
    return figure
def update_figure(
    predictor: LitTrajectoryPredictor,
    dataset: Dataset,
    index: int,
    risk_level: float,
    n_samples: int,
    image = None
) -> go.Figure:
    """Gradio callback: rebuild the figure from the current slider values.

    ``image`` receives the current Plot component value (it is listed in the click
    inputs) and is ignored. The figure is always recomputed from scratch, exactly
    as ``get_figure`` does.
    """
    return get_figure(predictor, dataset, index, risk_level, n_samples)
def load_predictor_from_hf(
    model_source: str = "TRI-ML/risk_biased_model",
    config_name: str = "learning_config.py",
    checkpoint_name: str = "last.ckpt",
    device: str = "cpu",
) -> LitTrajectoryPredictor:
    """Download a config and checkpoint from the Hugging Face Hub and build the predictor.

    Args:
        model_source: Hub repository id holding the config and checkpoint.
        config_name: Config file name inside the repository.
        checkpoint_name: Checkpoint file name inside the repository.
        device: Device the returned predictor is moved to.

    Returns:
        The predictor with loaded weights, in eval mode, on ``device``.
        Only the predictor is returned; the previous annotation wrongly
        advertised a ``(predictor, dataset)`` tuple.
    """
    # Optional access token for private repositories. It is None when the
    # variable is unset. It is read once and reused for both downloads.
    auth_token = os.getenv('SECRET_AUTH_TOKEN')
    config_file = hf_hub_download(model_source, filename=config_name, use_auth_token=auth_token)
    checkpoint_file = hf_hub_download(model_source, filename=checkpoint_name, use_auth_token=auth_token)
    # NOTE(review): torch.load unpickles the checkpoint. Only load from trusted repositories.
    ckpt = torch.load(checkpoint_file, map_location="cpu")
    cfg = Config.fromfile(config_file)
    predictor = get_predictor(cfg, WaymoDataloaders.unnormalize_trajectory)
    predictor = load_weights(predictor, ckpt)
    predictor.eval()
    predictor = predictor.to(device)
    return predictor
def load_dataset_from_hf(data_source: str = "jmercat/risk_biased_dataset") -> Dataset:
    """Load the "test" split of ``data_source`` with ``datasets.load_dataset``.

    ``trust_remote_code=True`` allows the repository's own loading code to run.
    Only point this at trusted dataset repositories.
    """
    return load_dataset(data_source, split="test", trust_remote_code=True)
def main(load_from=None, cfg_path=None):
    """Load the dataset and a predictor, then launch the Gradio demo.

    Args:
        load_from: Optional path to a local checkpoint. When None, the model is
            downloaded with ``load_predictor_from_hf``.
        cfg_path: Config file used to build the predictor for ``load_from``.
            Required whenever ``load_from`` is given.

    Raises:
        ValueError: if ``load_from`` is given without ``cfg_path``.
    """
    if load_from is not None and cfg_path is None:
        # Fail before the dataset download instead of later inside Config.fromfile(None).
        raise ValueError("cfg_path is required when load_from is given.")
    # Define the device to use
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Getting dataset")
    dataset = load_dataset_from_hf()
    if load_from is not None:
        cfg = Config.fromfile(cfg_path)
        predictor = get_predictor(cfg, WaymoDataloaders.unnormalize_trajectory)
        # NOTE(review): torch.load unpickles the file. Only load trusted checkpoints.
        predictor = load_weights(predictor, torch.load(load_from, map_location="cpu"))
        # Mirror load_predictor_from_hf: evaluation mode on the selected device.
        predictor.eval()
        predictor = predictor.to(device)
    else:
        print("Getting model.")
        predictor = load_predictor_from_hf(device=device)
    ui_update_fn = partial(update_figure, predictor, dataset)
    with gr.Blocks() as interface:
        gr.Markdown(
            """
# Risk-Aware Prediction
Make predictions for the green agent with a risk-seeking bias towards the ego vehicle in blue.
The risk level is a value between 0 and 1, where 0 is not risk-seeking and 1 is the most risk-seeking.
Once the sliders are set, click the "Run" button to see the predictions.
The play button will animate the prediction over time (it is slow especially with many samples).
For more information, see the paper [RAP: Risk-Aware Prediction for Robust Planning](https://arxiv.org/abs/2210.01368) published at [CoRL 2022](https://corl2022.org/).
"""
        )
        initial_index = 27
        initial_n_samples = 10
        image = gr.Plot(get_figure(predictor, dataset, initial_index, 0, initial_n_samples))
        interface.queue()
        index = gr.Slider(
            minimum=0,
            maximum=len(dataset) - 1,
            step=1,
            value=initial_index,
            label="Index",
        )
        risk_level = gr.Slider(minimum=0, maximum=1, step=0.01, label="Risk")
        n_samples = gr.Slider(
            minimum=1, maximum=20, step=1, value=initial_n_samples, label="Number of prediction samples"
        )
        # The button text is gr.Button's ``value`` argument (it takes no ``label``).
        button = gr.Button(value="Run")
        # Predictions are recomputed only on click. Recomputing on each slider
        # ``.change`` ran on the first change and ignored changes made during the
        # computation, which left the plot out of sync with the sliders.
        button.click(ui_update_fn, inputs=[index, risk_level, n_samples, image], outputs=image)
    interface.launch(debug=False)
if __name__ == "__main__":
    # python-fire exposes main's keyword arguments (load_from, cfg_path) as
    # command-line flags, e.g. `--load_from=<ckpt> --cfg_path=<config>`.
    fire.Fire(main)