---
license: apache-2.0
tags:
- jax
- rl
- jumanji
---

# CVRP-v1

This model is an A2C agent trained on the Jumanji CVRP (Capacitated Vehicle Routing Problem) environment.

**Developed by:** InstaDeep

### Model Sources

- **Repository:** [Jumanji](https://github.com/instadeepai/jumanji)
- **Paper:** TBD

### How to use

[Notebook](#)

The model code and its requirements live in the Jumanji repository. Clone the repository, navigate to its root directory, and install it in editable mode:

```
pip install -e .
```

Below is an example script that does three things:

1. Downloads the trained weights from the Hugging Face Hub.
2. Rolls out a few episodes with the deterministic (non-stochastic) policy.
3. Saves an animated GIF and a rendered frame as a PNG.

Note: the training state is deserialized with `pickle`, which can execute arbitrary code. Only load files from a source you trust.

```python
import pickle

import jax
from hydra import compose, initialize
from huggingface_hub import hf_hub_download

from jumanji.training.setup_train import setup_agent, setup_env
from jumanji.training.utils import first_from_device

# Initialise the config
with initialize(version_base=None, config_path="jumanji/training/configs"):
    cfg = compose(config_name="config.yaml", overrides=["env=cvrp", "agent=a2c"])

# Get the model state from the Hugging Face Hub
REPO_ID = "InstaDeepAI/jumanji-cvrp-v1-a2c-benchmark"
FILENAME = "CVRP-v1_training_state"
model_weights = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
with open(model_weights, "rb") as f:
    training_state = pickle.load(f)

params = first_from_device(training_state.params_state.params)
env = setup_env(cfg).unwrapped
agent = setup_agent(cfg, env)
policy = jax.jit(agent.make_policy(params.actor, stochastic=False))

# Roll out a few episodes
NUM_EPISODES = 10

states = []
key = jax.random.PRNGKey(cfg.seed)
for episode in range(NUM_EPISODES):
    key, reset_key = jax.random.split(key)
    state, timestep = jax.jit(env.reset)(reset_key)
    while not timestep.last():
        key, action_key = jax.random.split(key)
        observation = jax.tree_util.tree_map(lambda x: x[None], timestep.observation)
        action, _ = policy(observation, action_key)
        state, timestep = jax.jit(env.step)(state, action.squeeze(axis=0))
        states.append(state)
    # Freeze the terminal frame to pause the GIF.
    for _ in range(10):
        states.append(state)

# Animate a GIF
env.animate(states, interval=150).save("./cvrp.gif")

# Save a PNG of a single frame
import matplotlib.pyplot as plt

# %matplotlib inline  # uncomment when running in a Jupyter notebook (not valid in a plain script)
env.render(states[117])
plt.savefig("cvrp.png", dpi=300)
```