"""Download the pretrained risk-biased predictor from the Hugging Face Hub.

The script fetches the learning configuration and the last checkpoint from
the ``jmercat/risk_biased_model`` repository, builds the predictor from the
configuration and loads the checkpoint weights into it.  After it runs, the
module-level names ``config_file``, ``ckpt``, ``cfg`` and ``predictor`` are
available.
"""

from huggingface_hub import hf_hub_url, cached_download
from mmcv import Config
import torch

from risk_biased.utils.load_model import get_predictor
from risk_biased.utils.torch_utils import load_weights
from risk_biased.utils.waymo_dataloader import WaymoDataloaders

# Hub repository that hosts both the configuration and the checkpoint.
REPO_ID = "jmercat/risk_biased_model"

# NOTE(review): `cached_download` is deprecated in recent huggingface_hub
# releases in favour of `hf_hub_download` -- confirm the pinned version
# before migrating.  `force_filename` keeps the ".py" extension on the cached
# copy (presumably needed for Config.fromfile to detect the file type).
config_file = cached_download(
    hf_hub_url(REPO_ID, filename="learning_config.py"),
    # Same name as the remote file (the previous value was misspelled
    # "learing_config.py").
    force_filename="learning_config.py",
)

checkpoint_file = cached_download(
    hf_hub_url(REPO_ID, filename="last.ckpt"),
    force_filename="last.ckpt",
)
# SECURITY: torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints from a repository you trust.
# map_location="cpu" places the loaded tensors on the CPU.
ckpt = torch.load(checkpoint_file, map_location="cpu")

cfg = Config.fromfile(config_file)
# The second argument is the dataloader's trajectory un-normalization
# callable, passed through to the predictor factory.
predictor = get_predictor(cfg, WaymoDataloaders.unnormalize_trajectory)
predictor = load_weights(predictor, ckpt)

print("Model loaded")