---
license: apache-2.0
tags:
- jax
- rl
- jumanji
---

# BinPack-v2

This model is an A2C agent trained on the Jumanji BinPack-v2 environment.

**Developed by:** InstaDeep

### Model Sources

- **Repository:** [Jumanji](https://github.com/instadeepai/jumanji)
- **Paper:** TBD

### How to use

[Notebook](#)

The model code and its requirements live in the Jumanji repository. Clone the repository, navigate to its root directory, and install the package in editable mode:

```
pip install -e .
```
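
Before running the example script, it may help to confirm that the modules it imports can be found in the active environment. The following is a minimal sketch using only the standard library; the helper name `missing_modules` and the `REQUIRED` list are illustrative assumptions and are not part of Jumanji.

```python
import importlib.util


def missing_modules(names):
    """Return the top-level module names that cannot be found (hypothetical helper)."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# Top-level modules used by the example script (assumed list, adjust as needed).
REQUIRED = ["jax", "hydra", "huggingface_hub", "jumanji", "matplotlib"]

if __name__ == "__main__":
    missing = missing_modules(REQUIRED)
    print("Missing modules:", missing if missing else "none")
```

If any names are reported as missing, install them in the same environment before continuing.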

Below is an example script that loads the trained model and runs it in the environment:

```python
import pickle

import jax
import matplotlib.pyplot as plt
from hydra import compose, initialize
from huggingface_hub import hf_hub_download

from jumanji.training.setup_train import setup_agent, setup_env
from jumanji.training.utils import first_from_device

# Initialise the config for the BinPack environment and the A2C agent.
with initialize(version_base=None, config_path="jumanji/training/configs"):
    cfg = compose(config_name="config.yaml", overrides=["env=bin_pack", "agent=a2c"])

# Download the model state from the Hugging Face Hub.
REPO_ID = "InstaDeepAI/jumanji-binpack-v2-a2c-benchmark"
FILENAME = "BinPack-v2_training_state"

model_weights = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)

with open(model_weights, "rb") as f:
    training_state = pickle.load(f)

# Build the environment, the agent, and a greedy (non-stochastic) policy.
params = first_from_device(training_state.params_state.params)
env = setup_env(cfg).unwrapped
agent = setup_agent(cfg, env)
policy = jax.jit(agent.make_policy(params.actor, stochastic=False))

# JIT-compile reset and step once, outside the rollout loop.
reset_fn = jax.jit(env.reset)
step_fn = jax.jit(env.step)

# Roll out a few episodes.
NUM_EPISODES = 10

states = []
key = jax.random.PRNGKey(cfg.seed)
for episode in range(NUM_EPISODES):
    key, reset_key = jax.random.split(key)
    state, timestep = reset_fn(reset_key)
    while not timestep.last():
        key, action_key = jax.random.split(key)
        # The policy expects a leading batch axis on every observation leaf.
        observation = jax.tree_util.tree_map(lambda x: x[None], timestep.observation)
        action, _ = policy(observation, action_key)
        state, timestep = step_fn(state, action.squeeze(axis=0))
        states.append(state)
    # Freeze the terminal frame to pause the GIF.
    for _ in range(10):
        states.append(state)

# Animate a GIF.
env.animate(states, interval=150).save("./binpack.gif")

# Save a PNG of a single frame. In a notebook, run `%matplotlib inline` first.
frame = min(117, len(states) - 1)  # guard against short rollouts
env.render(states[frame])
plt.savefig("binpack.png", dpi=300)
```
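
The rollout loop adds a leading batch axis to the observation before calling the policy and removes it from the action before stepping the environment. The sketch below isolates that pattern on a stand-in pytree. The field names, shapes, and action values are assumptions chosen for illustration; they are not the actual BinPack observation or action spec.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a single, unbatched observation pytree.
observation = {
    "items": jnp.zeros((20, 3)),
    "action_mask": jnp.ones((40, 20), dtype=bool),
}

# Add a leading batch axis of size 1 to every leaf of the pytree.
batched = jax.tree_util.tree_map(lambda x: x[None], observation)

# Hypothetical batched policy output: one row of two indices.
batched_action = jnp.array([[3, 7]])

# Drop the batch axis again before passing the action to the environment.
action = batched_action.squeeze(axis=0)

print(batched["items"].shape, batched["action_mask"].shape, action.shape)
```

Because `tree_map` applies the same function to every leaf, the pattern works regardless of how many fields the observation contains.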