d-byrne's picture
typo
36ca4d6
metadata
license: apache-2.0
tags:
  - jax
  - rl
  - jumanji

BinPack-v2

This model is an A2C agent trained on the Jumanji BinPack-v2 environment.

Developed by: InstaDeep

Model Sources

How to use

Notebook

The agent implementation and its requirements live in the Jumanji repository. Clone the repository, navigate to its root directory, and install it in editable mode:

pip install -e .

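The example script below downloads the trained model from the Hugging Face Hub, runs it for a few episodes, and saves the rollout as an animated GIF and a single-frame PNG: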

import pickle
import joblib

import jax
from hydra import compose, initialize
from huggingface_hub import hf_hub_download


from jumanji.training.setup_train import setup_agent, setup_env
from jumanji.training.utils import first_from_device

# Initialise the Hydra config for the BinPack environment and the A2C agent.
# The env override must match the BinPack-v2 checkpoint downloaded below;
# loading another environment's config (e.g. snake) would build an env/agent
# that does not correspond to these weights.
# NOTE(review): assumes the BinPack config is `configs/env/bin_pack.yaml` in the
# Jumanji repo — confirm the file name if this override fails to resolve.
with initialize(version_base=None, config_path="jumanji/training/configs"):
    cfg = compose(config_name="config.yaml", overrides=["env=bin_pack", "agent=a2c"])

# Fetch the pickled training state for this model from the Hugging Face Hub.
REPO_ID = "InstaDeepAI/jumanji-binpack-v2-a2c-benchmark"
FILENAME = "BinPack-v2_training_state"

checkpoint_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)

# NOTE: unpickling can execute arbitrary code — only load checkpoints from a
# repository you trust.
with open(checkpoint_path, "rb") as checkpoint_file:
    training_state = pickle.load(checkpoint_file)

# Pull the parameters off the device layout used during training
# (first_from_device presumably selects the first device's replica — confirm),
# rebuild the environment and agent from the config, and jit-compile a
# deterministic (stochastic=False) policy from the actor parameters.
params = first_from_device(training_state.params_state.params)
env = setup_env(cfg).unwrapped
agent = setup_agent(cfg, env)
greedy_policy = agent.make_policy(params.actor, stochastic=False)
policy = jax.jit(greedy_policy)

# Roll out a few episodes with the policy, recording every post-step state.
NUM_EPISODES = 10
# Copies of the terminal state appended after each episode so the GIF pauses
# on the final frame.
FREEZE_FRAMES = 10

# Compile reset/step once, outside the loop. Calling jax.jit on env.step inside
# the loop builds a new wrapper around a new bound-method object every step,
# which can defeat JAX's compilation cache and force repeated tracing.
reset_fn = jax.jit(env.reset)
step_fn = jax.jit(env.step)

states = []
key = jax.random.PRNGKey(cfg.seed)
for _ in range(NUM_EPISODES):
    key, reset_key = jax.random.split(key)
    state, timestep = reset_fn(reset_key)
    while not timestep.last():
        key, action_key = jax.random.split(key)
        # Add a leading axis of size 1 to every observation leaf (removed again
        # from the action via squeeze) — presumably the policy works on batches.
        observation = jax.tree_util.tree_map(lambda x: x[None], timestep.observation)
        action, _ = policy(observation, action_key)
        state, timestep = step_fn(state, action.squeeze(axis=0))
        states.append(state)
    # Freeze the terminal frame to pause the GIF.
    states.extend([state] * FREEZE_FRAMES)

# Animate the recorded rollout and save it as a GIF.
env.animate(states, interval=150).save("./binpack.gif")

# Save a single frame of the rollout as a PNG.
import matplotlib.pyplot as plt

# `%matplotlib inline` is IPython magic, not valid Python; when running in a
# Jupyter notebook, execute it in its own cell to display figures inline.

# Show frame 117 when the rollout is long enough; otherwise fall back to the
# last recorded state instead of raising IndexError.
FRAME_INDEX = min(117, len(states) - 1)
env.render(states[FRAME_INDEX])
plt.savefig("binpack.png", dpi=300)