tree3po committed (verified)
Commit e0f25ed · Parent(s): d64c84d

Upload 190 files

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full list.
Files changed (50)
  1. .gitattributes +8 -0
  2. Kinetix/.gitignore +194 -0
  3. Kinetix/.pre-commit-config.yaml +7 -0
  4. Kinetix/LICENSE +19 -0
  5. Kinetix/README.md +217 -0
  6. Kinetix/configs/editor.yaml +22 -0
  7. Kinetix/configs/env/entity.yaml +3 -0
  8. Kinetix/configs/env/symbolic.yaml +3 -0
  9. Kinetix/configs/env_size/custom.yaml +3 -0
  10. Kinetix/configs/env_size/l.yaml +8 -0
  11. Kinetix/configs/env_size/m.yaml +8 -0
  12. Kinetix/configs/env_size/s.yaml +8 -0
  13. Kinetix/configs/eval/eval_all.yaml +82 -0
  14. Kinetix/configs/eval/eval_auto.yaml +4 -0
  15. Kinetix/configs/eval/eval_general.yaml +7 -0
  16. Kinetix/configs/eval/l.yaml +46 -0
  17. Kinetix/configs/eval/m.yaml +30 -0
  18. Kinetix/configs/eval/mujoco.yaml +13 -0
  19. Kinetix/configs/eval/s.yaml +16 -0
  20. Kinetix/configs/eval_env_size/l.yaml +7 -0
  21. Kinetix/configs/eval_env_size/m.yaml +7 -0
  22. Kinetix/configs/eval_env_size/s.yaml +7 -0
  23. Kinetix/configs/learning/ppo-base.yaml +20 -0
  24. Kinetix/configs/learning/ppo-rnn.yaml +2 -0
  25. Kinetix/configs/learning/ppo-sfl.yaml +1 -0
  26. Kinetix/configs/learning/ppo-ued.yaml +2 -0
  27. Kinetix/configs/misc/misc.yaml +16 -0
  28. Kinetix/configs/model/model-base.yaml +4 -0
  29. Kinetix/configs/model/model-transformer.yaml +6 -0
  30. Kinetix/configs/plr.yaml +17 -0
  31. Kinetix/configs/ppo.yaml +20 -0
  32. Kinetix/configs/sfl.yaml +21 -0
  33. Kinetix/configs/train_levels/l.yaml +44 -0
  34. Kinetix/configs/train_levels/m.yaml +28 -0
  35. Kinetix/configs/train_levels/mujoco.yaml +11 -0
  36. Kinetix/configs/train_levels/random.yaml +2 -0
  37. Kinetix/configs/train_levels/s.yaml +14 -0
  38. Kinetix/configs/train_levels/train_all.yaml +80 -0
  39. Kinetix/configs/ued/accel.yaml +16 -0
  40. Kinetix/configs/ued/plr.yaml +17 -0
  41. Kinetix/configs/ued/sfl.yaml +9 -0
  42. Kinetix/docs/README.md +83 -0
  43. Kinetix/docs/configs.md +179 -0
  44. Kinetix/examples/example_premade_level_replay.py +46 -0
  45. Kinetix/examples/example_random_level_replay.py +51 -0
  46. Kinetix/experiments/plr.py +1143 -0
  47. Kinetix/experiments/ppo.py +468 -0
  48. Kinetix/experiments/sfl.py +1067 -0
  49. Kinetix/images/bb.gif +0 -0
  50. Kinetix/images/cartpole.gif +0 -0
.gitattributes CHANGED
@@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ Kinetix/images/general_2.gif filter=lfs diff=lfs merge=lfs -text
37
+ Kinetix/images/kinetix_logo.gif filter=lfs diff=lfs merge=lfs -text
38
+ Kinetix/images/random_1.gif filter=lfs diff=lfs merge=lfs -text
39
+ Kinetix/images/random_3.gif filter=lfs diff=lfs merge=lfs -text
40
+ Kinetix/images/random_4.gif filter=lfs diff=lfs merge=lfs -text
41
+ Kinetix/images/random_5.gif filter=lfs diff=lfs merge=lfs -text
42
+ Kinetix/images/random_6.gif filter=lfs diff=lfs merge=lfs -text
43
+ Kinetix/images/random_7.gif filter=lfs diff=lfs merge=lfs -text
Kinetix/.gitignore ADDED
@@ -0,0 +1,194 @@
1
+ tmp/
2
+ wandb/
3
+ runs/
4
+
5
+ play_data
6
+ checkpoints
7
+
8
+ # Byte-compiled / optimized / DLL files
9
+ __pycache__/
10
+ *.py[cod]
11
+ *$py.class
12
+
13
+ # C extensions
14
+ *.so
15
+
16
+ # Distribution / packaging
17
+ .Python
18
+ build/
19
+ develop-eggs/
20
+ dist/
21
+ downloads/
22
+ eggs/
23
+ .eggs/
24
+ lib/
25
+ lib64/
26
+ parts/
27
+ sdist/
28
+ var/
29
+ wheels/
30
+ share/python-wheels/
31
+ *.egg-info/
32
+ .installed.cfg
33
+ *.egg
34
+ MANIFEST
35
+
36
+ # PyInstaller
37
+ # Usually these files are written by a python script from a template
38
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
39
+ *.manifest
40
+ *.spec
41
+
42
+ # Installer logs
43
+ pip-log.txt
44
+ pip-delete-this-directory.txt
45
+
46
+ # Unit test / coverage reports
47
+ htmlcov/
48
+ .tox/
49
+ .nox/
50
+ .coverage
51
+ .coverage.*
52
+ .cache*
53
+ .cache_*
54
+ nosetests.xml
55
+ coverage.xml
56
+ *.cover
57
+ *.py,cover
58
+ .hypothesis/
59
+ .pytest_cache/
60
+ cover/
61
+
62
+ # Translations
63
+ *.mo
64
+ *.pot
65
+
66
+ # Django stuff:
67
+ *.log
68
+ local_settings.py
69
+ db.sqlite3
70
+ db.sqlite3-journal
71
+
72
+ # Flask stuff:
73
+ instance/
74
+ .webassets-cache
75
+
76
+ # Scrapy stuff:
77
+ .scrapy
78
+
79
+ # Sphinx documentation
80
+ docs/_build/
81
+
82
+ # PyBuilder
83
+ .pybuilder/
84
+ target/
85
+
86
+ # Jupyter Notebook
87
+ .ipynb_checkpoints
88
+
89
+ # IPython
90
+ profile_default/
91
+ ipython_config.py
92
+
93
+ # pyenv
94
+ # For a library or package, you might want to ignore these files since the code is
95
+ # intended to run in multiple environments; otherwise, check them in:
96
+ # .python-version
97
+
98
+ # pipenv
99
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
100
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
101
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
102
+ # install all needed dependencies.
103
+ #Pipfile.lock
104
+
105
+ # poetry
106
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
107
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
108
+ # commonly ignored for libraries.
109
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
110
+ #poetry.lock
111
+
112
+ # pdm
113
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
114
+ #pdm.lock
115
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
116
+ # in version control.
117
+ # https://pdm.fming.dev/#use-with-ide
118
+ .pdm.toml
119
+
120
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121
+ __pypackages__/
122
+
123
+ # Celery stuff
124
+ celerybeat-schedule
125
+ celerybeat.pid
126
+
127
+ # SageMath parsed files
128
+ *.sage.py
129
+
130
+ # Environments
131
+ .env
132
+ .venv
133
+ env/
134
+ venv/
135
+ ENV/
136
+ env.bak/
137
+ venv.bak/
138
+
139
+
140
+ !configs/env
141
+
142
+ # Spyder project settings
143
+ .spyderproject
144
+ .spyproject
145
+
146
+ # Rope project settings
147
+ .ropeproject
148
+
149
+ # mkdocs documentation
150
+ /site
151
+
152
+ # mypy
153
+ .mypy_cache/
154
+ .dmypy.json
155
+ dmypy.json
156
+
157
+ # Pyre type checker
158
+ .pyre/
159
+
160
+ # pytype static type analyzer
161
+ .pytype/
162
+
163
+ # Cython debug symbols
164
+ cython_debug/
165
+
166
+ # PyCharm
167
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
168
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
169
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
170
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
171
+ .idea/
172
+ texture_cache.pbz2
173
+ texture_cache*.pbz2
174
+ profile*
175
+ wandb_key
176
+ test.py
177
+ outputs
178
+ lol*
179
+ .cache-location
180
+ experiments/ppo_old.py
181
+ .bash_history
182
+ logs/
183
+ .vscode
184
+
185
+ kinetix/util/old_learning_with_mask.py
186
+ offline/datasets/*.pkl
187
+ all_sweeps
188
+ worlds/games
189
+
190
+ artifacts
191
+ log*_*
192
+ kinetix/analysis/test*.py
193
+ slurm-*.out
194
+ results/
Kinetix/.pre-commit-config.yaml ADDED
@@ -0,0 +1,7 @@
1
+ repos:
2
+ - repo: https://github.com/psf/black
3
+ rev: 22.3.0
4
+ hooks:
5
+ - id: black
6
+ language_version: python3
7
+ args: [--line-length=120]
Kinetix/LICENSE ADDED
@@ -0,0 +1,19 @@
1
+ Copyright (c) 2024 Michael Matthews
2
+
3
+ Permission is hereby granted, free of charge, to any person obtaining a copy
4
+ of this software and associated documentation files (the "Software"), to deal
5
+ in the Software without restriction, including without limitation the rights
6
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7
+ copies of the Software, and to permit persons to whom the Software is
8
+ furnished to do so, subject to the following conditions:
9
+
10
+ The above copyright notice and this permission notice shall be included in all
11
+ copies or substantial portions of the Software.
12
+
13
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19
+ SOFTWARE.
Kinetix/README.md ADDED
@@ -0,0 +1,217 @@
1
+ <p align="middle">
2
+ <img src="images/kinetix_logo.gif" width="500" />
3
+ </p>
4
+
5
+ <p align="center">
6
+ <a href= "https://pypi.org/project/jax2d/">
7
+ <img src="https://img.shields.io/badge/python-3.10%20%7C%203.11%20%7C%203.12-blue" /></a>
8
+ <a href= "https://github.com/FLAIROx/Kinetix/blob/main/LICENSE">
9
+ <img src="https://img.shields.io/badge/License-MIT-yellow" /></a>
10
+ <a href= "https://github.com/psf/black">
11
+ <img src="https://img.shields.io/badge/code%20style-black-000000.svg" /></a>
12
+ <a href= "https://kinetix-env.github.io/">
13
+ <img src="https://img.shields.io/badge/online-editor-purple" /></a>
14
+ <a href= "https://arxiv.org/abs/2410.23208">
15
+ <img src="https://img.shields.io/badge/arxiv-2410.23208-b31b1b" /></a>
16
+ <a href= "./docs/README.md">
17
+ <img src="https://img.shields.io/badge/docs-green" /></a>
18
+ </p>
19
+
20
+ # Kinetix
21
+
22
+ Kinetix is a framework for reinforcement learning in a 2D rigid-body physics world, written entirely in [JAX](https://github.com/jax-ml/jax).
23
+ Kinetix can represent a huge array of physics-based tasks within a unified framework.
24
+ We use Kinetix to investigate the training of large, general reinforcement learning agents by procedurally generating millions of tasks for training.
25
+ You can play with Kinetix in our [online editor](https://kinetix-env.github.io/), or have a look at the JAX [physics engine](https://github.com/MichaelTMatthews/Jax2D) and [graphics library](https://github.com/FLAIROx/JaxGL) we made for Kinetix. Finally, see our [docs](./docs/README.md) for more information and more in-depth examples.
26
+
27
+ <p align="middle">
28
+ <img src="images/bb.gif" width="200" />
29
+ <img src="images/cartpole.gif" width="200" />
30
+ <img src="images/grasper.gif" width="200" />
31
+ </p>
32
+ <p align="middle">
33
+ <img src="images/hc.gif" width="200" />
34
+ <img src="images/hopper.gif" width="200" />
35
+ <img src="images/ll.gif" width="200" />
36
+ </p>
37
+
38
+ <p align="middle">
39
+ <b>The above shows specialist agents trained on their respective levels.</b>
40
+ </p>
41
+
42
+ # 📊 Paper TL;DR
43
+
44
+
45
+
46
+ We train a general agent on millions of procedurally generated physics tasks.
47
+ Every task has the same goal: make the <span style="color:green">green</span> and <span style="color:blue">blue</span> touch, without <span style="color:green">green</span> touching <span style="color:red">red</span>.
48
+ The agent acts by applying torque via motors and force via thrusters.
49
+
50
+ <p align="middle">
51
+ <img src="images/random_1.gif" width="200" />
52
+ <img src="images/random_5.gif" width="200" />
53
+ <img src="images/random_3.gif" width="200" />
54
+ </p>
55
+ <p align="middle">
56
+ <img src="images/random_4.gif" width="200" />
57
+ <img src="images/random_6.gif" width="200" />
58
+ <img src="images/random_7.gif" width="200" />
59
+ </p>
60
+
61
+ <p align="middle">
62
+ <b>The above shows a general agent zero-shotting unseen randomly generated levels.</b>
63
+ </p>
64
+
65
+ We then investigate the transfer capabilities of this agent to unseen handmade levels.
66
+ We find that the agent can zero-shot simple physics problems, but still struggles with harder tasks.
67
+
68
+ <p align="middle">
69
+ <img src="images/general_1.gif" width="200" />
70
+ <img src="images/general_2.gif" width="200" />
71
+ <img src="images/general_3.gif" width="200" />
72
+ </p>
73
+ <p align="middle">
74
+ <img src="images/general_4.gif" width="200" />
75
+ <img src="images/general_5.gif" width="200" />
76
+ <img src="images/general_6.gif" width="200" />
77
+ </p>
78
+
79
+ <p align="middle">
80
+ <b>The above shows a general agent zero-shotting unseen handmade levels.</b>
81
+ </p>
82
+
83
+
84
+ # 📜 Basic Usage
85
+
86
+ Kinetix follows the interfaces established in [gymnax](https://github.com/RobertTLange/gymnax) and [jaxued](https://github.com/DramaCow/jaxued):
87
+
88
+ ```python
89
+ # Use default parameters
90
+ env_params = EnvParams()
91
+ static_env_params = StaticEnvParams()
92
+ ued_params = UEDParams()
93
+
94
+ # Create the environment
95
+ env = make_kinetix_env_from_args(
96
+ obs_type="pixels",
97
+ action_type="multidiscrete",
98
+ reset_type="replay",
99
+ static_env_params=static_env_params,
100
+ )
101
+
102
+ # Sample a random level
103
+ rng = jax.random.PRNGKey(0)
104
+ rng, _rng = jax.random.split(rng)
105
+ level = sample_kinetix_level(_rng, env.physics_engine, env_params, static_env_params, ued_params)
106
+
107
+ # Reset the environment state to this level
108
+ rng, _rng = jax.random.split(rng)
109
+ obs, env_state = env.reset_to_level(_rng, level, env_params)
110
+
111
+ # Take a step in the environment
112
+ rng, _rng = jax.random.split(rng)
113
+ action = env.action_space(env_params).sample(_rng)
114
+ rng, _rng = jax.random.split(rng)
115
+ obs, env_state, reward, done, info = env.step(_rng, env_state, action, env_params)
116
+ ```
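+ The snippet above assumes the relevant imports are in scope; a minimal set (matching `examples/example_random_level_replay.py`) is:
+
+ ```python
+ import jax
+
+ from kinetix.environment.env import make_kinetix_env_from_args
+ from kinetix.environment.env_state import StaticEnvParams, EnvParams
+ from kinetix.environment.ued.distributions import sample_kinetix_level
+ from kinetix.environment.ued.ued_state import UEDParams
+ ```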
117
+
118
+
119
+ # ⬇️ Installation
120
+ To install Kinetix with a CUDA-enabled JAX backend (tested with python3.10):
121
+ ```commandline
122
+ git clone https://github.com/FlairOx/Kinetix.git
123
+ cd Kinetix
124
+ pip install -e .
125
+ pre-commit install
126
+ ```
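+ If JAX does not pick up your GPU after this, you may need to install a CUDA build of JAX explicitly (shown here for CUDA 12 wheels; adjust to your setup):
+
+ ```commandline
+ pip install -U "jax[cuda12]"
+ ```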
127
+
128
+ # 🎯 Editor
129
+ We recommend using the [KinetixJS editor](https://kinetix-env.github.io/gallery.html?editor=true), but also provide a native (less polished) Kinetix editor.
130
+
131
+ To open this editor, run the following command:
132
+ ```commandline
133
+ python3 kinetix/editor.py
134
+ ```
135
+
136
+ The controls in the editor are:
137
+ - Move between `edit` and `play` modes using `spacebar`
138
+ - In `edit` mode, the type of edit is shown by the icon at the top and is changed by scrolling the mouse wheel. For instance, by navigating to the rectangle editing function you can click to place a rectangle.
139
+ - You can also press the number keys to cycle between modes.
140
+ - To open handmade levels, press `ctrl+O` and navigate to the ones in the `worlds/l` folder.
141
+ - **When playing a level use the arrow keys to control motors and the numeric keys (1, 2) to control thrusters.**
142
+
143
+ # 📈 Experiments
144
+
145
+ We have three primary experiment files:
146
+ 1. [**SFL**](https://github.com/amacrutherford/sampling-for-learnability?tab=readme-ov-file): Training on levels with high learnability; this is how we trained our best general agents.
147
+ 2. **PLR**: PLR/DR/ACCEL in the [JAXUED](https://github.com/DramaCow/jaxued) style.
148
+ 3. **PPO**: Normal PPO in the [PureJaxRL](https://github.com/luchris429/purejaxrl/) style.
149
+
150
+ To run experiments with default parameters run any of the following:
151
+ ```commandline
152
+ python3 experiments/sfl.py
153
+ python3 experiments/plr.py
154
+ python3 experiments/ppo.py
155
+ ```
156
+
157
+ We use [hydra](https://hydra.cc/) for managing our configs. See the `configs/` folder for all the hydra configs that will be used by default.
158
+ If you want to run experiments with different configurations, you can either edit these configs or pass command-line arguments like so:
159
+
160
+ ```commandline
161
+ python3 experiments/sfl.py model.transformer_depth=8
162
+ ```
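+ Group selections and individual key overrides can be combined in one command; for example (all groups and keys as defined in `configs/`):
+
+ ```commandline
+ python3 experiments/ppo.py env=symbolic env_size=m eval=m eval_env_size=m learning.num_train_envs=512 misc.use_wandb=false
+ ```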
163
+
164
+ These experiments use [wandb](https://wandb.ai/home) for logging by default.
165
+
166
+ ## 🏋️ Training RL Agents
167
+ We provide several different ways to train RL agents, with the three most common options being (a) [training an agent on random levels](#training-on-random-levels), (b) [training an agent on a single, hand-designed level](#training-on-a-single-hand-designed-level), or (c) [training an agent on a set of hand-designed levels](#training-on-a-set-of-hand-designed-levels).
168
+
169
+ > [!WARNING]
170
+ > Kinetix has three different environment sizes, `s`, `m` and `l`. When running any of the scripts, you have to set the `env_size` option accordingly, for instance, `python3 experiments/ppo.py train_levels=random env_size=m` would train on random `m` levels.
171
+ > It will give an error if you try to load large levels into a small env size; for instance, `python3 experiments/ppo.py train_levels=m env_size=s` would error.
172
+
173
+ ### Training on random levels
174
+ This is the default option, but we give the explicit command for completeness:
175
+ ```commandline
176
+ python3 experiments/ppo.py train_levels=random
177
+ ```
178
+ ### Training on a single hand-designed level
179
+
180
+ > [!NOTE]
181
+ > Check the `worlds/` folder for handmade levels for each size category. By default, the loading functions require a path relative to the `worlds/` directory.
182
+
183
+ ```commandline
184
+ python3 experiments/ppo.py train_levels=s train_levels.train_levels_list='["s/h4_thrust_aim.json"]'
185
+ ```
186
+ ### Training on a set of hand-designed levels
187
+ ```commandline
188
+ python3 experiments/ppo.py train_levels=s env_size=s eval_env_size=s
189
+ # python3 experiments/ppo.py train_levels=m env_size=m eval_env_size=m
190
+ # python3 experiments/ppo.py train_levels=l env_size=l eval_env_size=l
191
+ ```
192
+
193
+ Or, on a custom set:
194
+ ```commandline
195
+ python3 experiments/ppo.py train_levels=l eval_env_size=l env_size=l train_levels.train_levels_list='["s/h2_one_wheel_car","l/h11_obstacle_avoidance"]'
196
+ ```
197
+
198
+
199
+ # 🔎 See Also
200
+ - 🌐 [Kinetix.js](https://github.com/Michael-Beukman/Kinetix.js) Kinetix reimplemented in Javascript, with a live demo [here](https://kinetix-env.github.io/gallery.html?editor=true).
201
+ - 🍎 [Jax2D](https://github.com/MichaelTMatthews/Jax2D) The physics engine we made for Kinetix.
202
+ - 👨‍💻 [JaxGL](https://github.com/FLAIROx/JaxGL) The graphics library we made for Kinetix.
203
+ - 📋 [Our Paper](https://arxiv.org/abs/2410.23208) for more details and empirical results.
204
+
205
+ # 📚 Citation
206
+ Please cite Kinetix as follows:
207
+ ```
208
+ @article{matthews2024kinetix,
209
+ title={Kinetix: Investigating the Training of General Agents through Open-Ended Physics-Based Control Tasks},
210
+ author={Michael Matthews and Michael Beukman and Chris Lu and Jakob Foerster},
211
+ year={2024},
212
+ eprint={2410.23208},
213
+ archivePrefix={arXiv},
214
+ primaryClass={cs.LG},
215
+ url={https://arxiv.org/abs/2410.23208},
216
+ }
217
+ ```
Kinetix/configs/editor.yaml ADDED
@@ -0,0 +1,22 @@
1
+ defaults:
2
+ - env: entity
3
+ - env_size: l
4
+ - learning:
5
+ - ppo-base
6
+ - ppo-rnn
7
+ - misc: misc
8
+ - model:
9
+ - model-base
10
+ - model-transformer
11
+ - _self_
12
+
13
+ seed: 0
14
+ upscale: 2
15
+ downscale: 1
16
+ fps: 60
17
+ debug: true
18
+
19
+ env:
20
+ frame_skip: 1
21
+
22
+ agent_taking_actions: false
Kinetix/configs/env/entity.yaml ADDED
@@ -0,0 +1,3 @@
1
+ env_name: "Kinetix-Entity-MultiDiscrete-v1"
2
+ dense_reward_scale: 2.0
3
+ frame_skip: 2
Kinetix/configs/env/symbolic.yaml ADDED
@@ -0,0 +1,3 @@
1
+ env_name: "Kinetix-Symbolic-MultiDiscrete-v1"
2
+ dense_reward_scale: 2.0
3
+ frame_skip: 2
Kinetix/configs/env_size/custom.yaml ADDED
@@ -0,0 +1,3 @@
1
+ custom_path: worlds/l/grasp_easy.json
2
+ env_size_type: custom
3
+ env_size_name: custom
Kinetix/configs/env_size/l.yaml ADDED
@@ -0,0 +1,8 @@
1
+ num_polygons: 12
2
+ num_circles: 4
3
+ num_joints: 6
4
+ num_thrusters: 2
5
+ env_size_name: l
6
+ num_motor_bindings: 4
7
+ num_thruster_bindings: 2
8
+ env_size_type: predefined
Kinetix/configs/env_size/m.yaml ADDED
@@ -0,0 +1,8 @@
1
+ num_polygons: 6
2
+ num_circles: 3
3
+ num_joints: 2
4
+ num_thrusters: 2
5
+ env_size_name: m
6
+ num_motor_bindings: 4
7
+ num_thruster_bindings: 2
8
+ env_size_type: predefined
Kinetix/configs/env_size/s.yaml ADDED
@@ -0,0 +1,8 @@
1
+ num_polygons: 5
2
+ num_circles: 2
3
+ num_joints: 1
4
+ num_thrusters: 1
5
+ env_size_name: s
6
+ num_motor_bindings: 4
7
+ num_thruster_bindings: 2
8
+ env_size_type: predefined
Kinetix/configs/eval/eval_all.yaml ADDED
@@ -0,0 +1,82 @@
1
+ eval_levels:
2
+ [
3
+ "s/h0_weak_thrust",
4
+ "s/h7_unicycle_left",
5
+ "s/h3_point_the_thruster",
6
+ "s/h4_thrust_aim",
7
+ "s/h1_thrust_over_ball",
8
+ "s/h5_rotate_fall",
9
+ "s/h9_explode_then_thrust_over",
10
+ "s/h6_unicycle_right",
11
+ "s/h8_unicycle_balance",
12
+ "s/h2_one_wheel_car",
13
+
14
+ "m/h0_unicycle",
15
+ "m/h1_car_left",
16
+ "m/h2_car_right",
17
+ "m/h3_car_thrust",
18
+ "m/h4_thrust_the_needle",
19
+ "m/h5_angry_birds",
20
+ "m/h6_thrust_over",
21
+ "m/h7_car_flip",
22
+ "m/h8_weird_vehicle",
23
+ "m/h9_spin_the_right_way",
24
+ "m/h10_thrust_right_easy",
25
+ "m/h11_thrust_left_easy",
26
+ "m/h12_thrustfall_left",
27
+ "m/h13_thrustfall_right",
28
+ "m/h14_thrustblock",
29
+ "m/h15_thrustshoot",
30
+ "m/h16_thrustcontrol_right",
31
+ "m/h17_thrustcontrol_left",
32
+ "m/h18_thrust_right_very_easy",
33
+ "m/h19_thrust_left_very_easy",
34
+ "m/arm_left",
35
+ "m/arm_right",
36
+ "m/arm_up",
37
+ "m/arm_hard",
38
+
39
+ "l/h0_angrybirds",
40
+ "l/h1_car_left",
41
+ "l/h2_car_ramp",
42
+ "l/h3_car_right",
43
+ "l/h4_cartpole",
44
+ "l/h5_flappy_bird",
45
+ "l/h6_lorry",
46
+ "l/h7_maze_1",
47
+ "l/h8_maze_2",
48
+ "l/h9_morph_direction",
49
+ "l/h10_morph_direction_2",
50
+ "l/h11_obstacle_avoidance",
51
+ "l/h12_platformer_1",
52
+ "l/h13_platformer_2",
53
+ "l/h14_simple_thruster",
54
+ "l/h15_swing_up",
55
+ "l/h16_thruster_goal",
56
+ "l/h17_unicycle",
57
+ "l/hard_beam_balance",
58
+ "l/hard_cartpole_thrust",
59
+ "l/hard_cartpole_wheels",
60
+ "l/hard_lunar_lander",
61
+ "l/hard_pinball",
62
+ "l/grasp_hard",
63
+ "l/grasp_easy",
64
+ "l/mjc_half_cheetah",
65
+ "l/mjc_half_cheetah_easy",
66
+ "l/mjc_hopper",
67
+ "l/mjc_hopper_easy",
68
+ "l/mjc_swimmer",
69
+ "l/mjc_walker",
70
+ "l/mjc_walker_easy",
71
+ "l/car_launch",
72
+ "l/car_swing_around",
73
+ "l/chain_lander",
74
+ "l/chain_thrust",
75
+ "l/gears",
76
+ "l/lever_puzzle",
77
+ "l/pr",
78
+ "l/rail",
79
+ ]
80
+ eval_num_attempts: 10
81
+ eval_freq: 10
82
+ EVAL_ON_SAMPLED: false
Kinetix/configs/eval/eval_auto.yaml ADDED
@@ -0,0 +1,4 @@
1
+ eval_levels: "auto"
2
+ eval_num_attempts: 10
3
+ eval_freq: 10
4
+ EVAL_ON_SAMPLED: false
Kinetix/configs/eval/eval_general.yaml ADDED
@@ -0,0 +1,7 @@
1
+ eval_levels:
2
+ [
3
+ "easy.simple_thruster",
4
+ ]
5
+ eval_num_attempts: 10
6
+ eval_freq: 10
7
+ EVAL_ON_SAMPLED: false
Kinetix/configs/eval/l.yaml ADDED
@@ -0,0 +1,46 @@
1
+ eval_levels:
2
+ [
3
+ "l/h0_angrybirds",
4
+ "l/h1_car_left",
5
+ "l/h2_car_ramp",
6
+ "l/h3_car_right",
7
+ "l/h4_cartpole",
8
+ "l/h5_flappy_bird",
9
+ "l/h6_lorry",
10
+ "l/h7_maze_1",
11
+ "l/h8_maze_2",
12
+ "l/h9_morph_direction",
13
+ "l/h10_morph_direction_2",
14
+ "l/h11_obstacle_avoidance",
15
+ "l/h12_platformer_1",
16
+ "l/h13_platformer_2",
17
+ "l/h14_simple_thruster",
18
+ "l/h15_swing_up",
19
+ "l/h16_thruster_goal",
20
+ "l/h17_unicycle",
21
+ "l/hard_beam_balance",
22
+ "l/hard_cartpole_thrust",
23
+ "l/hard_cartpole_wheels",
24
+ "l/hard_lunar_lander",
25
+ "l/hard_pinball",
26
+ "l/grasp_hard",
27
+ "l/grasp_easy",
28
+ "l/mjc_half_cheetah",
29
+ "l/mjc_half_cheetah_easy",
30
+ "l/mjc_hopper",
31
+ "l/mjc_hopper_easy",
32
+ "l/mjc_swimmer",
33
+ "l/mjc_walker",
34
+ "l/mjc_walker_easy",
35
+ "l/car_launch",
36
+ "l/car_swing_around",
37
+ "l/chain_lander",
38
+ "l/chain_thrust",
39
+ "l/gears",
40
+ "l/lever_puzzle",
41
+ "l/pr",
42
+ "l/rail",
43
+ ]
44
+ eval_num_attempts: 10
45
+ eval_freq: 50
46
+ EVAL_ON_SAMPLED: true
Kinetix/configs/eval/m.yaml ADDED
@@ -0,0 +1,30 @@
1
+ eval_levels:
2
+ [
3
+ "m/h0_unicycle",
4
+ "m/h1_car_left",
5
+ "m/h2_car_right",
6
+ "m/h3_car_thrust",
7
+ "m/h4_thrust_the_needle",
8
+ "m/h5_angry_birds",
9
+ "m/h6_thrust_over",
10
+ "m/h7_car_flip",
11
+ "m/h8_weird_vehicle",
12
+ "m/h9_spin_the_right_way",
13
+ "m/h10_thrust_right_easy",
14
+ "m/h11_thrust_left_easy",
15
+ "m/h12_thrustfall_left",
16
+ "m/h13_thrustfall_right",
17
+ "m/h14_thrustblock",
18
+ "m/h15_thrustshoot",
19
+ "m/h16_thrustcontrol_right",
20
+ "m/h17_thrustcontrol_left",
21
+ "m/h18_thrust_right_very_easy",
22
+ "m/h19_thrust_left_very_easy",
23
+ "m/arm_left",
24
+ "m/arm_right",
25
+ "m/arm_up",
26
+ "m/arm_hard",
27
+ ]
28
+ eval_num_attempts: 10
29
+ eval_freq: 50
30
+ EVAL_ON_SAMPLED: true
Kinetix/configs/eval/mujoco.yaml ADDED
@@ -0,0 +1,13 @@
1
+ eval_levels:
2
+ [
3
+ "l/mjc_half_cheetah",
4
+ "l/mjc_half_cheetah_easy",
5
+ "l/mjc_hopper",
6
+ "l/mjc_hopper_easy",
7
+ "l/mjc_swimmer",
8
+ "l/mjc_walker",
9
+ "l/mjc_walker_easy",
10
+ ]
11
+ eval_num_attempts: 10
12
+ eval_freq: 10
13
+ EVAL_ON_SAMPLED: false
Kinetix/configs/eval/s.yaml ADDED
@@ -0,0 +1,16 @@
1
+ eval_levels:
2
+ [
3
+ "s/h0_weak_thrust",
4
+ "s/h7_unicycle_left",
5
+ "s/h3_point_the_thruster",
6
+ "s/h4_thrust_aim",
7
+ "s/h1_thrust_over_ball",
8
+ "s/h5_rotate_fall",
9
+ "s/h9_explode_then_thrust_over",
10
+ "s/h6_unicycle_right",
11
+ "s/h8_unicycle_balance",
12
+ "s/h2_one_wheel_car",
13
+ ]
14
+ eval_num_attempts: 10
15
+ eval_freq: 50
16
+ EVAL_ON_SAMPLED: true
Kinetix/configs/eval_env_size/l.yaml ADDED
@@ -0,0 +1,7 @@
1
+ num_polygons: 12
2
+ num_circles: 4
3
+ num_joints: 6
4
+ num_thrusters: 2
5
+ env_size_name: l
6
+ num_motor_bindings: 4
7
+ num_thruster_bindings: 2
Kinetix/configs/eval_env_size/m.yaml ADDED
@@ -0,0 +1,7 @@
1
+ num_polygons: 6
2
+ num_circles: 3
3
+ num_joints: 2
4
+ num_thrusters: 2
5
+ env_size_name: m
6
+ num_motor_bindings: 4
7
+ num_thruster_bindings: 2
Kinetix/configs/eval_env_size/s.yaml ADDED
@@ -0,0 +1,7 @@
1
+ num_polygons: 5
2
+ num_circles: 2
3
+ num_joints: 1
4
+ num_thrusters: 1
5
+ env_size_name: s
6
+ num_motor_bindings: 4
7
+ num_thruster_bindings: 2
Kinetix/configs/learning/ppo-base.yaml ADDED
@@ -0,0 +1,20 @@
1
+ lr: 5e-5
2
+ peak_lr: 3e-4
3
+ initial_lr: 1e-5
4
+ warmup_frac: 0.1
5
+ max_grad_norm: 1.0
6
+ total_timesteps: 1073741824
7
+ num_train_envs: 2048
8
+ num_minibatches: 32
9
+ gamma: 0.995
10
+ update_epochs: 8
11
+ clip_eps: 0.2
12
+ gae_lambda: 0.9
13
+ ent_coef: 0.01
14
+ anneal_lr: false
15
+ warmup_lr: false
16
+ vf_coef: 0.5
17
+ permute_state_during_training: false
18
+ filter_levels: true
19
+ level_filter_n_steps: 64
20
+ level_filter_sample_ratio: 2
Kinetix/configs/learning/ppo-rnn.yaml ADDED
@@ -0,0 +1,2 @@
1
+ num_steps: 64
2
+ num_repeats: 1
Kinetix/configs/learning/ppo-sfl.yaml ADDED
@@ -0,0 +1 @@
1
+ num_steps: 512
Kinetix/configs/learning/ppo-ued.yaml ADDED
@@ -0,0 +1,2 @@
1
+ num_steps: 64
2
+ outer_rollout_steps: 4
Kinetix/configs/misc/misc.yaml ADDED
@@ -0,0 +1,16 @@
1
+ group: "auto"
2
+ group_auto_prefix: ""
3
+ save_path: "checkpoints/kinetix"
4
+ use_wandb: true
5
+ save_policy: true
6
+ wandb_project: "kinetix-experiments"
7
+ wandb_entity: null
8
+ wandb_mode : online
9
+ video_frequency: 10
10
+ load_from_checkpoint: null
11
+ load_only_params: true
12
+ checkpoint_save_freq: 512
13
+ checkpoint_human_numbers: false
14
+ load_legacy_checkpoint: false
15
+ load_train_levels_legacy: false
16
+ economical_saving: false
Kinetix/configs/model/model-base.yaml ADDED
@@ -0,0 +1,4 @@
1
+ fc_layer_depth: 5
2
+ fc_layer_width: 128
3
+ activation: "tanh"
4
+ recurrent_model: False
Kinetix/configs/model/model-transformer.yaml ADDED
@@ -0,0 +1,6 @@
1
+ transformer_depth: 2
2
+ transformer_size: 16
3
+ transformer_encoder_size: 128
4
+ num_heads: 8
5
+ full_attention_mask: false
6
+ aggregate_mode: dummy_and_mean
Kinetix/configs/plr.yaml ADDED
@@ -0,0 +1,17 @@
1
+ defaults:
2
+ - env: entity
3
+ - learning:
4
+ - ppo-base
5
+ - ppo-ued
6
+ - misc: misc
7
+ - env_size: s
8
+ - eval: s
9
+ - eval_env_size: s
10
+ - ued: plr
11
+ - train_levels: random
12
+ - model:
13
+ - model-base
14
+ - model-transformer
15
+ - _self_
16
+
17
+ seed: 0
Kinetix/configs/ppo.yaml ADDED
@@ -0,0 +1,20 @@
1
+ defaults:
2
+ - env: entity
3
+ - env_size: s
4
+ - learning:
5
+ - ppo-base
6
+ - ppo-rnn
7
+ - misc: misc
8
+ - eval: s
9
+ - eval_env_size: s
10
+ - train_levels: random
11
+ - model:
12
+ - model-base
13
+ - model-transformer
14
+ - _self_
15
+
16
+
17
+ eval:
18
+ eval_freq: 40
19
+
20
+ seed: 0
Kinetix/configs/sfl.yaml ADDED
@@ -0,0 +1,21 @@
1
+ defaults:
2
+ - env: entity
3
+ - learning:
4
+ - ppo-base
5
+ - ppo-rnn
6
+ - misc: misc
7
+ - ued: sfl
8
+ - env_size: s
9
+ - eval: s
10
+ - eval_env_size: s
11
+ - train_levels: random
12
+ - model:
13
+ - model-base
14
+ - model-transformer
15
+ - _self_
16
+
17
+ eval:
18
+ eval_freq: 128
19
+ learning:
20
+ num_steps: 256
21
+ seed: 0
Kinetix/configs/train_levels/l.yaml ADDED
@@ -0,0 +1,44 @@
1
+ train_level_mode: list
2
+ train_levels_list:
3
+ [
4
+ "l/h0_angrybirds",
5
+ "l/h1_car_left",
6
+ "l/h2_car_ramp",
7
+ "l/h3_car_right",
8
+ "l/h4_cartpole",
9
+ "l/h5_flappy_bird",
10
+ "l/h6_lorry",
11
+ "l/h7_maze_1",
12
+ "l/h8_maze_2",
13
+ "l/h9_morph_direction",
14
+ "l/h10_morph_direction_2",
15
+ "l/h11_obstacle_avoidance",
16
+ "l/h12_platformer_1",
17
+ "l/h13_platformer_2",
18
+ "l/h14_simple_thruster",
19
+ "l/h15_swing_up",
20
+ "l/h16_thruster_goal",
21
+ "l/h17_unicycle",
22
+ "l/hard_beam_balance",
23
+ "l/hard_cartpole_thrust",
24
+ "l/hard_cartpole_wheels",
25
+ "l/hard_lunar_lander",
26
+ "l/hard_pinball",
27
+ "l/mjc_half_cheetah",
28
+ "l/mjc_half_cheetah_easy",
29
+ "l/mjc_hopper",
30
+ "l/mjc_hopper_easy",
31
+ "l/mjc_swimmer",
32
+ "l/mjc_walker",
33
+ "l/mjc_walker_easy",
34
+ "l/grasp_hard",
35
+ "l/grasp_easy",
36
+ "l/car_launch",
37
+ "l/car_swing_around",
38
+ "l/chain_lander",
39
+ "l/chain_thrust",
40
+ "l/gears",
41
+ "l/lever_puzzle",
42
+ "l/pr",
43
+ "l/rail",
44
+ ]
Kinetix/configs/train_levels/m.yaml ADDED
@@ -0,0 +1,28 @@
1
+ train_level_mode: list
2
+ train_levels_list:
3
+ [
4
+ "m/h0_unicycle",
5
+ "m/h1_car_left",
6
+ "m/h2_car_right",
7
+ "m/h3_car_thrust",
8
+ "m/h4_thrust_the_needle",
9
+ "m/h5_angry_birds",
10
+ "m/h6_thrust_over",
11
+ "m/h7_car_flip",
12
+ "m/h8_weird_vehicle",
13
+ "m/h9_spin_the_right_way",
14
+ "m/h10_thrust_right_easy",
15
+ "m/h11_thrust_left_easy",
16
+ "m/h12_thrustfall_left",
17
+ "m/h13_thrustfall_right",
18
+ "m/h14_thrustblock",
19
+ "m/h15_thrustshoot",
20
+ "m/h16_thrustcontrol_right",
21
+ "m/h17_thrustcontrol_left",
22
+ "m/h18_thrust_right_very_easy",
23
+ "m/h19_thrust_left_very_easy",
24
+ "m/arm_left",
25
+ "m/arm_right",
26
+ "m/arm_up",
27
+ "m/arm_hard",
28
+ ]
Kinetix/configs/train_levels/mujoco.yaml ADDED
@@ -0,0 +1,11 @@
1
+ train_level_mode: list
2
+ train_levels_list:
3
+ [
4
+ "l/mjc_half_cheetah",
5
+ "l/mjc_half_cheetah_easy",
6
+ "l/mjc_hopper",
7
+ "l/mjc_hopper_easy",
8
+ "l/mjc_swimmer",
9
+ "l/mjc_walker",
10
+ "l/mjc_walker_easy",
11
+ ]
Kinetix/configs/train_levels/random.yaml ADDED
@@ -0,0 +1,2 @@
1
+ train_level_mode: random
2
+ train_level_distribution: distribution_v3
Kinetix/configs/train_levels/s.yaml ADDED
@@ -0,0 +1,14 @@
1
+ train_level_mode: list
2
+ train_levels_list:
3
+ [
4
+ "s/h0_weak_thrust",
5
+ "s/h7_unicycle_left",
6
+ "s/h3_point_the_thruster",
7
+ "s/h4_thrust_aim",
8
+ "s/h1_thrust_over_ball",
9
+ "s/h5_rotate_fall",
10
+ "s/h9_explode_then_thrust_over",
11
+ "s/h6_unicycle_right",
12
+ "s/h8_unicycle_balance",
13
+ "s/h2_one_wheel_car",
14
+ ]
Kinetix/configs/train_levels/train_all.yaml ADDED
@@ -0,0 +1,80 @@
1
+ train_level_mode: list
2
+ train_levels_list:
3
+ [
4
+ "s/h0_weak_thrust",
5
+ "s/h7_unicycle_left",
6
+ "s/h3_point_the_thruster",
7
+ "s/h4_thrust_aim",
8
+ "s/h1_thrust_over_ball",
9
+ "s/h5_rotate_fall",
10
+ "s/h9_explode_then_thrust_over",
11
+ "s/h6_unicycle_right",
12
+ "s/h8_unicycle_balance",
13
+ "s/h2_one_wheel_car",
14
+
15
+ "m/h0_unicycle",
16
+ "m/h1_car_left",
17
+ "m/h2_car_right",
18
+ "m/h3_car_thrust",
19
+ "m/h4_thrust_the_needle",
20
+ "m/h5_angry_birds",
21
+ "m/h6_thrust_over",
22
+ "m/h7_car_flip",
23
+ "m/h8_weird_vehicle",
24
+ "m/h9_spin_the_right_way",
25
+ "m/h10_thrust_right_easy",
26
+ "m/h11_thrust_left_easy",
27
+ "m/h12_thrustfall_left",
28
+ "m/h13_thrustfall_right",
29
+ "m/h14_thrustblock",
30
+ "m/h15_thrustshoot",
31
+ "m/h16_thrustcontrol_right",
32
+ "m/h17_thrustcontrol_left",
33
+ "m/h18_thrust_right_very_easy",
34
+ "m/h19_thrust_left_very_easy",
35
+ "m/arm_left",
36
+ "m/arm_right",
37
+ "m/arm_up",
38
+ "m/arm_hard",
39
+
40
+ "l/h0_angrybirds",
41
+ "l/h1_car_left",
42
+ "l/h2_car_ramp",
43
+ "l/h3_car_right",
44
+ "l/h4_cartpole",
45
+ "l/h5_flappy_bird",
46
+ "l/h6_lorry",
47
+ "l/h7_maze_1",
48
+ "l/h8_maze_2",
49
+ "l/h9_morph_direction",
50
+ "l/h10_morph_direction_2",
51
+ "l/h11_obstacle_avoidance",
52
+ "l/h12_platformer_1",
53
+ "l/h13_platformer_2",
54
+ "l/h14_simple_thruster",
55
+ "l/h15_swing_up",
56
+ "l/h16_thruster_goal",
57
+ "l/h17_unicycle",
58
+ "l/hard_beam_balance",
59
+ "l/hard_cartpole_thrust",
60
+ "l/hard_cartpole_wheels",
61
+ "l/hard_lunar_lander",
62
+ "l/hard_pinball",
63
+ "l/grasp_hard",
64
+ "l/grasp_easy",
65
+ "l/mjc_half_cheetah",
66
+ "l/mjc_half_cheetah_easy",
67
+ "l/mjc_hopper",
68
+ "l/mjc_hopper_easy",
69
+ "l/mjc_swimmer",
70
+ "l/mjc_walker",
71
+ "l/mjc_walker_easy",
72
+ "l/car_launch",
73
+ "l/car_swing_around",
74
+ "l/chain_lander",
75
+ "l/chain_thrust",
76
+ "l/gears",
77
+ "l/lever_puzzle",
78
+ "l/pr",
79
+ "l/rail",
80
+ ]
Kinetix/configs/ued/accel.yaml ADDED
@@ -0,0 +1,16 @@
1
+ use_accel: true
2
+ exploratory_grad_updates: true
3
+ num_edits: 5
4
+ score_function: MaxMC
5
+ level_buffer_capacity: 4000
6
+ replay_prob: 0.5
7
+ staleness_coeff: 0.3
8
+ temperature: 1.0
9
+ topk_k: 8
10
+ minimum_fill_ratio: 0.5
11
+ prioritization: rank
12
+ buffer_duplicate_check: false
13
+ buffer_train: false
14
+ mode: train
15
+ checkpoint_directory: checkpoints/physicsenv/ued
16
+ max_number_of_checkpoints: 5
Kinetix/configs/ued/plr.yaml ADDED
@@ -0,0 +1,17 @@
1
+ use_accel: false
2
+ exploratory_grad_updates: true
3
+ num_edits: 2
4
+ score_function: MaxMC
5
+ level_buffer_capacity: 4000
6
+ replay_prob: 0.5
7
+ staleness_coeff: 0.3
8
+ temperature: 1.0
9
+ topk_k: 8
10
+ minimum_fill_ratio: 0.5
11
+ prioritization: rank
12
+ buffer_duplicate_check: false
13
+ buffer_train: false
14
+ mode: train
15
+ checkpoint_directory: checkpoints/physicsenv/ued
16
+ max_number_of_checkpoints: 5
17
+ accel_start_from_empty: True
Kinetix/configs/ued/sfl.yaml ADDED
@@ -0,0 +1,9 @@
1
+ "sampled_envs_ratio": 0.5
2
+ "batch_size": 4096
3
+ "num_batches": 3
4
+ "rollout_steps": 512
5
+ "num_to_save": 1024
6
+
7
+ log_learnability_before_after: false
8
+ put_eval_levels_in_buffer: false
9
+ save_learnability_buffer_pickle: false
Kinetix/docs/README.md ADDED
@@ -0,0 +1,83 @@
1
+ # Documentation
2
+ This is intended to provide some more details about how Kinetix works, including more in-depth examples. If you are interested in the configuration options, see [here](./configs.md).
3
+
4
+ - [Documentation](#documentation)
5
+ - [Different Versions of Kinetix Environments](#different-versions-of-kinetix-environments)
6
+ - [Action Spaces](#action-spaces)
7
+ - [Observation Spaces](#observation-spaces)
8
+ - [Resetting Functionality](#resetting-functionality)
9
+ - [Using Kinetix to easily design your own JAX Environments](#using-kinetix-to-easily-design-your-own-jax-environments)
10
+ - [Step 1 - Design an Environment](#step-1---design-an-environment)
11
+ - [Step 2 - Export It](#step-2---export-it)
12
+ - [Step 3 - Import It](#step-3---import-it)
13
+ - [Step 4 - Train](#step-4---train)
14
+
15
+
16
+ ## Different Versions of Kinetix Environments
17
+ We provide several different variations on the standard Kinetix environment, where the primary difference is the action and observation spaces.
18
+
19
+ Each of the environments has a different name, of the following form: `Kinetix-<OBS>-<ACTION>-v1`, and can be made using the `make_kinetix_env_from_name` helper function.
20
+ ### Action Spaces
21
+ For all action spaces, the agent can control joints and thrusters. Joints have a property `motor_binding`, which is a way to tie different joints to the same action. Two joints that have the same binding will always perform the same action, likewise for thrusters.
22
+
23
+ We have three action spaces: discrete, continuous, and multi-discrete (which is the default).
24
+ - **Discrete** has `2 * num_motor_bindings + num_thruster_bindings + 1` options, one of which can be active at any time. There are two options for every joint, i.e., backward and forward at full power. There is one option for each thruster, to activate it at full power. The final option is a no-op, meaning that no torque or force is applied to joints/thrusters.
25
+ - **Continuous** has shape `num_motor_bindings + num_thruster_bindings`, where each motor element can take a value between -1 and 1, and thruster elements can take values between 0 and 1.
26
+ - **Multi-Discrete**: This is a discrete action space, but allows multiple joints and thrusters to be active at any one time. The agent must output a flat vector of size `3 * num_motor_bindings + 2 * num_thruster_bindings`. For joints, each group of three represents a categorical distribution over `[0, -1, +1]`, and for thrusters it represents `[0, +1]` (see the decoding sketch below).
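+
+ As a rough illustration of the multi-discrete layout, the sketch below decodes a flat vector of logits into per-binding motor and thruster actions. It assumes the motor groups come before the thruster groups (as suggested by the size formula above) and is purely illustrative, not the library's internal implementation:
+
+ ```python
+ import numpy as np
+
+ num_motor_bindings, num_thruster_bindings = 4, 2
+ # Flat multi-discrete output: 3 logits per motor binding, 2 per thruster binding.
+ logits = np.random.randn(3 * num_motor_bindings + 2 * num_thruster_bindings)
+
+ motor_logits = logits[: 3 * num_motor_bindings].reshape(num_motor_bindings, 3)
+ thruster_logits = logits[3 * num_motor_bindings :].reshape(num_thruster_bindings, 2)
+
+ # Each group of three is a categorical over [0, -1, +1]; each group of two over [0, +1].
+ motor_action = np.array([0.0, -1.0, +1.0])[motor_logits.argmax(axis=-1)]
+ thruster_action = np.array([0.0, +1.0])[thruster_logits.argmax(axis=-1)]
+ print(motor_action, thruster_action)
+ ```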
27
+
28
+ ### Observation Spaces
29
+ We provide three primary observation spaces: Symbolic-Flat (called simply symbolic), Symbolic-Entity (called entity, which is also the default), and Pixels.
30
+ - **Symbolic-Flat** returns a large vector, which is the flattened representation of all shapes and their properties.
31
+ - **Symbolic-Entity** also returns a vector representation of all entities, but does not flatten it, instead returning it in a form that can be used with permutation-invariant network architectures, such as transformers.
32
+ - **Pixels** returns an image representation of the scene. This is partially observable, as features such as the restitution and density of shapes are not shown.
33
+
34
+
35
+ Each observation space has its own pros and cons. **Symbolic-Flat** is the fastest by far, but has two clear downsides. First, it is restricted to a single environment size, e.g. a model trained on `small` cannot be run on `medium` levels. Second, due to the large number of symmetries (e.g. any permutation of the same shapes would represent the same scene but would look very different in this observation space), this generalises worse than *entity*.
36
+
37
+ **Symbolic-Entity** is faster than pixels, but slower than Symbolic-Flat. However, it can be applied to any number of shapes, and is natively permutation invariant. For these reasons we chose it as the default option.
38
+
39
+ Finally, **Pixels** runs the slowest, and also requires more memory, which means that we cannot run as many parallel environments. However, pixels is potentially the most general format, and could theoretically allow transfer to other domains and simulators.
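+
+ As a sketch, the observation space is chosen when constructing the environment. The `"pixels"` value below matches the main README; the `"symbolic"` and `"entity"` strings are assumed to mirror the config names and should be checked against `make_kinetix_env_from_args`:
+
+ ```python
+ from kinetix.environment.env import make_kinetix_env_from_args
+ from kinetix.environment.env_state import StaticEnvParams
+
+ static_env_params = StaticEnvParams()
+
+ # One environment per observation space (obs_type strings other than "pixels" are assumptions).
+ envs = {
+     obs: make_kinetix_env_from_args(
+         obs_type=obs,
+         action_type="multidiscrete",
+         reset_type="replay",
+         static_env_params=static_env_params,
+     )
+     for obs in ("symbolic", "entity", "pixels")
+ }
+ ```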
40
+
41
+
42
+ ## Resetting Functionality
43
+ We have two primary resetting functions that control the environment's behaviour when an episode ends. The first is to train on a known, predefined set of levels, where each reset samples a new level from this set; in the extreme case, this also allows training on only a single level in the standard RL manner. The other main way of resetting is to sample a *random* level from some distribution, meaning that it is exceedingly unlikely to sample the same level twice.
44
+
45
+ ## Using Kinetix to easily design your own JAX Environments
46
+ Since Kinetix has a general physics engine, you can design your own environments and train RL agents on them very fast! This section in the docs describes this pipeline.
47
+ ### Step 1 - Design an Environment
48
+ You can go to our [online editor](https://kinetix-env.github.io/gallery.html?editor=true). You can also have a look at the [gallery](https://kinetix-env.github.io/gallery.html) if you need some inspiration.
49
+
50
+ The following two images show the main editor page, and then the level I designed, where you have to spin the ball the right way. While designing the level, you can play it to test it out, seeing if it is possible and of the appropriate difficulty.
51
+
52
+
53
+ <p align="middle">
54
+ <img src="../images/docs/edit-1.png" width="49%" />
55
+ <img src="../images/docs/edit-2.png" width="49%" />
56
+ </p>
57
+
58
+ ### Step 2 - Export It
59
+ Once you are satisfied with your level, you can download it as a json file by using the button on the bottom left. Once this is downloaded, move it to `$KINETIX_ROOT/worlds/custom/my_custom_level.json`, where `$KINETIX_ROOT` is the root of the Kinetix repo.
60
+
61
+
62
+ ### Step 3 - Import It
63
+ In Python, you can import the level as follows (see `examples/example_premade_level_replay.py` for a full example):
64
+ ```python
65
+ from kinetix.util.saving import load_from_json_file
66
+ level, static_env_params, env_params = load_from_json_file("worlds/custom/my_custom_level.json")
67
+ ```
68
+
69
+ ### Step 4 - Train
70
+ You can use the above if you want to import the level and play around with it. If you want to train an RL agent on this level, you can do the following (see [this section](https://github.com/FLAIROx/Kinetix?tab=readme-ov-file#training-on-a-single-hand-designed-level) of the main README).
71
+
72
+ ```commandline
73
+ python3 experiments/ppo.py env_size=custom \
74
+ env_size.custom_path=custom/my_custom_level.json \
75
+ train_levels=s \
76
+ train_levels.train_levels_list='["custom/my_custom_level.json"]' \
77
+ eval=eval_auto
78
+ ```
79
+
80
+ The agent will then start training, with videos of its progress logged to [wandb](https://wandb.ai).
81
+ <p align="middle">
82
+ <img src="../images/docs/wandb.gif" width="49%" />
83
+ </p>
Kinetix/docs/configs.md ADDED
@@ -0,0 +1,179 @@
1
+ # Configuration
2
+
3
+ - [Configuration](#configuration)
4
+ - [Configuration Headings](#configuration-headings)
5
+ - [Env](#env)
6
+ - [Env Size](#env-size)
7
+ - [Learning](#learning)
8
+ - [Misc](#misc)
9
+ - [Eval](#eval)
10
+ - [Eval Env Size](#eval-env-size)
11
+ - [Train Levels](#train-levels)
12
+ - [Model](#model)
13
+ - [UED](#ued)
14
+
15
+
16
+ We use [hydra](https://hydra.cc) for all of our configurations, and we use [hierarchical configuration](https://hydra.cc/docs/tutorials/structured_config/schema/) to keep everything organised.
17
+
18
+ In particular, we have the following configuration headings, with the base `ppo` config looking like:
19
+ ```yaml
20
+ defaults:
21
+ - env: entity
22
+ - env_size: s
23
+ - learning:
24
+ - ppo-base
25
+ - ppo-rnn
26
+ - misc: misc
27
+ - eval: s
28
+ - eval_env_size: s
29
+ - train_levels: random
30
+ - model:
31
+ - model-base
32
+ - model-transformer
33
+ - _self_
34
+ seed: 0
35
+ ```
36
+
37
+ ## Configuration Headings
38
+ ### Env
39
+ This controls the environment to be used.
40
+ #### Preset Options
41
+ We provide two options in `configs/env`, namely `entity` and `symbolic`; each of these can be used by running `python3 experiments/ppo.py env=symbolic` or `python3 experiments/ppo.py env=entity`. If you wish to customise the options further, you can add any of the following subkeys (e.g. by running `python3 experiments/ppo.py env=symbolic env.dense_reward_scale=0.0`):
42
+ #### Individual Subkeys
43
+ - `env.env_name`: The name of the environment, which controls the observation and action space.
44
+ - `env.dense_reward_scale`: How large the dense reward scale is, set this to zero to disable dense rewards.
45
+ - `env.frame_skip`: The number of frames to skip, setting this to 2 (the default) seems to perform better.
46
+ ### Env Size
47
+ This controls the maximum number of shapes present in the simulation. This has two important tradeoffs, namely speed and representational power: small environments run much faster, but some complex environments require a large number of shapes. See `configs/env_size`.
48
+ #### Preset Options
49
+ - `s`: The `small` preset
50
+ - `m`: `Medium` preset
51
+ - `l`: `Large` preset
52
+ - `custom`: Allows the use of a custom environment size loaded from a json file (see [here](#train-levels) for more).
53
+ #### Individual Subkeys
54
+ - `num_polygons`: How many polygons
55
+ - `num_circles`: How many circles
56
+ - `num_joints`: How many joints
57
+ - `num_thrusters`: How many thrusters
58
+ - `env_size_name`: "s", "m" or "l"
59
+ - `num_motor_bindings`: How many different joint bindings there are, meaning how many different actions are associated with joints. All joints with the same binding will have the same action applied to them.
60
+ - `num_thruster_bindings`: How many different thruster bindings are there
61
+ - `env_size_type`: "predefined" or "custom"
62
+ - `custom_path`: **Only for env_size_type=custom**, controls the json file to load the custom environment size from.
63
+ ### Learning
64
+ This controls the agent's learning, see `configs/learning`
65
+ #### Preset Options
66
+ - `ppo-base`: This has all of the base PPO parameters, and is used by all methods
67
+ - `ppo-rnn`: This has the PureJaxRL settings for some of PPO's hyperparameters (mainly `num_steps` is different)
68
+ - `ppo-sfl`: This has the SFL-specific value of `num_steps`
69
+ - `ppo-ued`: This has the JAXUED-specific `num_steps` and `outer_rollout_steps`
70
+ #### Individual Subkeys
71
+ - `lr`: Learning Rate
72
+ - `anneal_lr`: Whether to anneal LR
73
+ - `warmup_lr`: Whether to warmup LR
74
+ - `peak_lr`: If warming up, the peak
75
+ - `initial_lr`: If warming up, the initial LR
76
+ - `warmup_frac`: If warming up, the warmup fraction of training time
77
+ - `max_grad_norm`: Maximum grad norm
78
+ - `total_timesteps`: How many total environment interactions must be run
79
+ - `num_train_envs`: Number of parallel environments to run simultaneously
80
+ - `num_minibatches`: Minibatches for PPO learning
81
+ - `gamma`: Discount factor
82
+ - `update_epochs`: PPO update epochs
83
+ - `clip_eps`: PPO clipping epsilon
84
+ - `gae_lambda`: PPO Lambda for GAE
85
+ - `ent_coef`: Entropy loss coefficient
86
+ - `vf_coef`: Value function loss coefficient
87
+ - `permute_state_during_training`: If true, the state is permuted on every reset.
88
+ - `filter_levels`: If true, and we are training on random levels, this filters out levels that can be solved by a no-op
89
+ - `level_filter_n_steps`: How many steps to allocate to the no-op policy for filtering
90
+ - `level_filter_sample_ratio`: How many more levels to sample than required (ideally `level_filter_sample_ratio` is more than the fraction that will be filtered out).
91
+ - `num_steps`: PPO rollout length
92
+ - `outer_rollout_steps`: How many learning steps to do for e.g. PLR for each rollout (see the [Craftax paper](https://arxiv.org/abs/2402.16801) for a more in-depth explanation).
93
+ ### Misc
94
+ There are a plethora of miscellaneous options that are grouped under the `misc` category. There is only one preset option, `configs/misc/misc.yaml`.
95
+ #### Individual Subkeys
96
+ - `group`: Wandb group ("auto" usually works well)
97
+ - `group_auto_prefix`: If using group=auto, this is a user-defined prefix
98
+ - `save_path`: Where to save checkpoints to
99
+ - `use_wandb`: Should wandb be logged to
100
+ - `save_policy`: Should we save the policy
101
+ - `wandb_project`: Wandb project
102
+ - `wandb_entity`: Wandb entity, leave as `null` to use your default one
103
+ - `wandb_mode` : Wandb mode
104
+ - `video_frequency`: How often to log videos (they are quite large)
105
+ - `load_from_checkpoint`: Wandb artifact path to load from
106
+ - `load_only_params`: Whether to load just the network parameters or entire train state.
107
+ - `checkpoint_save_freq`: How often to save checkpoints
108
+ - `checkpoint_human_numbers`: Should the checkpoints have human-readable timestep numbers
109
+ - `load_legacy_checkpoint`: Do not use
110
+ - `load_train_levels_legacy`: Do not use
111
+ - `economical_saving`: If true, only saves a few important checkpoints for space conservation purposes.
112
+ ### Eval
113
+ This option (see `configs/eval`) controls how evaluation works, and what levels are used.
114
+ #### Preset Options
115
+ - `s`: Eval on the `s` hand-designed levels located in `worlds/s`
116
+ - `m`: Eval on the `m` hand-designed levels located in `worlds/m`
117
+ - `l`: Eval on the `l` hand-designed levels located in `worlds/l`
118
+ - `eval_all`: Eval on all of the hand-designed eval levels
119
+ - `eval_auto`: If `train_levels` is not random, evaluate on the training levels.
120
+ - `mujoco`: Eval on the recreations of the mujoco tasks.
121
+ - `eval_general`: General option if you are planning on overwriting most options.
122
+ #### Individual Subkeys
123
+ - `eval_levels`: List of eval levels or the string "auto"
124
+ - `eval_num_attempts`: How many times to eval on the same level
125
+ - `eval_freq`: How often to evaluate
126
+ - `EVAL_ON_SAMPLED`: If true, in `plr.py` and `sfl.py`, evaluates on a fixed set of randomly-generated levels
127
+
128
+ ### Eval Env Size
129
+ This controls the size of the evaluation environment. This is crucial to match up with the size of the evaluation levels.
130
+ #### Preset Options
131
+ - `s`: Same as the `env_size` option.
132
+ - `m`: Same as the `env_size` option.
133
+ - `l`: Same as the `env_size` option.
134
+ ### Train Levels
135
+ Which levels to train on.
136
+ #### Preset Options
137
+ - `s`: All of the `s` holdout levels
138
+ - `m`: All of the `m` holdout levels
139
+ - `l`: All of the `l` holdout levels
140
+ - `train_all`: All of the levels from all 3 holdout sets
141
+ - `mujoco`: All of the mujoco recreation levels.
142
+ - `random`: Train on random levels
143
+ #### Individual Subkeys
144
+ - `train_level_mode`: "random" or "list"
145
+ - `train_level_distribution`: if train_level_mode=random, this controls which distribution to use. By default `distribution_v3`
146
+ - `train_levels_list`: The list of levels to train on; see the example below.
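+
+ For example, a custom list config in the style of `configs/train_levels/s.yaml` could look like:
+
+ ```yaml
+ train_level_mode: list
+ train_levels_list:
+   [
+     "s/h2_one_wheel_car",
+     "l/h11_obstacle_avoidance",
+   ]
+ ```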
147
+ ### Model
148
+ This controls the model architecture and options associated with that.
149
+ #### Preset Options
150
+ We use both of the following:
151
+ - `model-base`
152
+ - `model-transformer`
153
+ #### Individual Subkeys
154
+ - `fc_layer_depth`: How many layers in the FC model
155
+ - `fc_layer_width`: How wide each FC layer is
156
+ - `activation`: NN activation
157
+ - `recurrent_model`: Whether or not to use recurrence
158
+ The following are only relevant when using `env=entity`:
159
+ - `transformer_depth`: How many transformer layers to use
160
+ - `transformer_size`: How large the KQV vectors are
161
+ - `transformer_encoder_size`: How large the initial embeddings are
162
+ - `num_heads`: How many attention heads; must be a multiple of 4 and divide `transformer_size` evenly.
163
+ - `full_attention_mask`: If true, all heads use the full attention mask
164
+ - `aggregate_mode`: `dummy_and_mean` works well.
165
+ ### UED
166
+ Options pertaining to UED (i.e., when using the scripts `plr.py` or `sfl.py`)
167
+ #### Preset Options
168
+ - `sfl`
169
+ - `plr`
170
+ - `accel`
171
+ #### Individual Subkeys
172
+ See the individual files for the configuration options used.
173
+ For SFL, we have the following (an example override command is given after the list):
174
+
175
+ - `sampled_envs_ratio`: How many environments are from the SFL buffer and how many are randomly generated
176
+ - `batch_size`: How many levels to evaluate learnability on per batch
177
+ - `num_batches`: How many batches to run when choosing the most learnable levels
178
+ - `rollout_steps`: How many steps to rollout for when doing the learnability calculation.
179
+ - `num_to_save`: How many levels to save in the learnability buffer
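+
+ An illustrative SFL run overriding two of these keys (key names from `configs/ued/sfl.yaml`; the values are arbitrary):
+
+ ```commandline
+ python3 experiments/sfl.py ued.sampled_envs_ratio=0.75 ued.num_to_save=512
+ ```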
Kinetix/examples/example_premade_level_replay.py ADDED
@@ -0,0 +1,46 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import jax.random
4
+ from jax2d.engine import PhysicsEngine
5
+ from matplotlib import pyplot as plt
6
+
7
+ from kinetix.environment.env import make_kinetix_env_from_args
8
+ from kinetix.environment.env_state import StaticEnvParams, EnvParams
9
+ from kinetix.environment.ued.distributions import sample_kinetix_level
10
+ from kinetix.environment.ued.ued_state import UEDParams
11
+ from kinetix.render.renderer_pixels import make_render_pixels
12
+ from kinetix.util.saving import load_from_json_file
13
+
14
+
15
+ def main():
16
+ # Load a premade level
17
+ level, static_env_params, env_params = load_from_json_file("worlds/l/grasp_easy.json")
18
+
19
+ # Create the environment
20
+ env = make_kinetix_env_from_args(
21
+ obs_type="pixels", action_type="continuous", reset_type="replay", static_env_params=static_env_params
22
+ )
23
+
24
+ # Reset the environment state to this level
25
+ rng = jax.random.PRNGKey(0)
26
+ rng, _rng = jax.random.split(rng)
27
+ obs, env_state = env.reset_to_level(_rng, level, env_params)
28
+
29
+ # Take a step in the environment
30
+ rng, _rng = jax.random.split(rng)
31
+ action = env.action_space(env_params).sample(_rng)
32
+ rng, _rng = jax.random.split(rng)
33
+ obs, env_state, reward, done, info = env.step(_rng, env_state, action, env_params)
34
+
35
+ # Render environment
36
+ renderer = make_render_pixels(env_params, static_env_params)
37
+
38
+ # There are a lot of wrappers
39
+ pixels = renderer(env_state.env_state.env_state.env_state)
40
+
41
+ plt.imshow(pixels.astype(jnp.uint8).transpose(1, 0, 2)[::-1])
42
+ plt.show()
43
+
44
+
45
+ if __name__ == "__main__":
46
+ main()
Kinetix/examples/example_random_level_replay.py ADDED
@@ -0,0 +1,51 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import jax.random
4
+ from jax2d.engine import PhysicsEngine
5
+ from matplotlib import pyplot as plt
6
+
7
+ from kinetix.environment.env import make_kinetix_env_from_args
8
+ from kinetix.environment.env_state import StaticEnvParams, EnvParams
9
+ from kinetix.environment.ued.distributions import sample_kinetix_level
10
+ from kinetix.environment.ued.ued_state import UEDParams
11
+ from kinetix.render.renderer_pixels import make_render_pixels
12
+
13
+
14
+ def main():
15
+ # Use default parameters
16
+ env_params = EnvParams()
17
+ static_env_params = StaticEnvParams()
18
+ ued_params = UEDParams()
19
+
20
+ # Create the environment
21
+ env = make_kinetix_env_from_args(
22
+ obs_type="pixels", action_type="continuous", reset_type="replay", static_env_params=static_env_params
23
+ )
24
+
25
+ # Sample a random level
26
+ rng = jax.random.PRNGKey(0)
27
+ rng, _rng = jax.random.split(rng)
28
+ level = sample_kinetix_level(_rng, env.physics_engine, env_params, static_env_params, ued_params)
29
+
30
+ # Reset the environment state to this level
31
+ rng, _rng = jax.random.split(rng)
32
+ obs, env_state = env.reset_to_level(_rng, level, env_params)
33
+
34
+ # Take a step in the environment
35
+ rng, _rng = jax.random.split(rng)
36
+ action = env.action_space(env_params).sample(_rng)
37
+ rng, _rng = jax.random.split(rng)
38
+ obs, env_state, reward, done, info = env.step(_rng, env_state, action, env_params)
39
+
40
+ # Render environment
41
+ renderer = make_render_pixels(env_params, static_env_params)
42
+
43
+ # There are a lot of wrappers
44
+ pixels = renderer(env_state.env_state.env_state.env_state)
45
+
46
+ plt.imshow(pixels.astype(jnp.uint8).transpose(1, 0, 2)[::-1])
47
+ plt.show()
48
+
49
+
50
+ if __name__ == "__main__":
51
+ main()
Kinetix/experiments/plr.py ADDED
@@ -0,0 +1,1143 @@
1
+ from functools import partial
2
+ import time
3
+ from enum import IntEnum
4
+ from typing import Tuple
5
+
6
+ import chex
7
+ import hydra
8
+ import jax
9
+ import jax.numpy as jnp
10
+ import numpy as np
11
+ from omegaconf import OmegaConf
12
+ import optax
13
+ from flax import core, struct
14
+ from flax.training.train_state import TrainState as BaseTrainState
15
+
16
+ import wandb
17
+ from kinetix.environment.ued.distributions import (
18
+ create_random_starting_distribution,
19
+ )
20
+ from kinetix.environment.ued.ued import (
+ make_mutate_env,
+ make_reset_train_function_with_list_of_levels,
+ make_reset_train_function_with_mutations,
+ make_vmapped_filtered_level_sampler,
+ )
30
+ from kinetix.util.config import (
31
+ generate_ued_params_from_config,
32
+ get_video_frequency,
33
+ init_wandb,
34
+ normalise_config,
35
+ save_data_to_local_file,
36
+ generate_params_from_config,
37
+ get_eval_level_groups,
38
+ )
39
+ from jaxued.environments.underspecified_env import EnvState
40
+ from jaxued.level_sampler import LevelSampler
41
+ from jaxued.utils import compute_max_returns, max_mc, positive_value_loss
42
+ from flax.serialization import to_state_dict
43
+
44
+ import sys
45
+
46
+ sys.path.append("experiments")
47
+ from kinetix.environment.env import make_kinetix_env_from_name
48
+ from kinetix.environment.env_state import StaticEnvParams
49
+ from kinetix.environment.wrappers import (
50
+ UnderspecifiedToGymnaxWrapper,
51
+ LogWrapper,
52
+ DenseRewardWrapper,
53
+ AutoReplayWrapper,
54
+ )
55
+ from kinetix.models import make_network_from_config
56
+ from kinetix.render.renderer_pixels import make_render_pixels
57
+ from kinetix.models.actor_critic import ScannedRNN
58
+ from kinetix.util.learning import (
59
+ general_eval,
60
+ get_eval_levels,
61
+ no_op_and_random_rollout,
62
+ sample_trajectories_and_learn,
63
+ )
64
+ from kinetix.util.saving import (
65
+ load_train_state_from_wandb_artifact_path,
66
+ save_model_to_wandb,
67
+ )
68
+
69
+
70
+ class UpdateState(IntEnum):
71
+ DR = 0
72
+ REPLAY = 1
73
+ MUTATE = 2
74
+
75
+
76
+ def get_level_complexity_metrics(all_levels: EnvState, static_env_params: StaticEnvParams):
77
+ def get_for_single_level(level):
78
+ return {
79
+ "complexity/num_shapes": level.polygon.active[static_env_params.num_static_fixated_polys :].sum()
80
+ + level.circle.active.sum(),
81
+ "complexity/num_joints": level.joint.active.sum(),
82
+ "complexity/num_thrusters": level.thruster.active.sum(),
83
+ "complexity/num_rjoints": (level.joint.active * jnp.logical_not(level.joint.is_fixed_joint)).sum(),
84
+ "complexity/num_fjoints": (level.joint.active * (level.joint.is_fixed_joint)).sum(),
85
+ "complexity/has_ball": ((level.polygon_shape_roles == 1) * level.polygon.active).sum()
86
+ + ((level.circle_shape_roles == 1) * level.circle.active).sum(),
87
+ "complexity/has_goal": ((level.polygon_shape_roles == 2) * level.polygon.active).sum()
88
+ + ((level.circle_shape_roles == 2) * level.circle.active).sum(),
89
+ }
90
+
91
+ return jax.tree.map(lambda x: x.mean(), jax.vmap(get_for_single_level)(all_levels))
92
+
93
+
94
+ def get_ued_score_metrics(all_ued_scores):
95
+ (mc, pvl, learn) = all_ued_scores
96
+ scores = {}
97
+ for score, name in zip([mc, pvl, learn], ["MaxMC", "PVL", "Learnability"]):
98
+ scores[f"ued_scores/{name}/Mean"] = score.mean()
99
+ scores[f"ued_scores_additional/{name}/Max"] = score.max()
100
+ scores[f"ued_scores_additional/{name}/Min"] = score.min()
101
+
102
+ return scores
103
+
104
+
105
+ class TrainState(BaseTrainState):
106
+ sampler: core.FrozenDict[str, chex.ArrayTree] = struct.field(pytree_node=True)
107
+ update_state: UpdateState = struct.field(pytree_node=True)
108
+ # === Below is used for logging ===
109
+ num_dr_updates: int
110
+ num_replay_updates: int
111
+ num_mutation_updates: int
112
+
113
+ dr_last_level_batch_scores: chex.ArrayTree = struct.field(pytree_node=True)
114
+ replay_last_level_batch_scores: chex.ArrayTree = struct.field(pytree_node=True)
115
+ mutation_last_level_batch_scores: chex.ArrayTree = struct.field(pytree_node=True)
116
+
117
+ dr_last_level_batch: chex.ArrayTree = struct.field(pytree_node=True)
118
+ replay_last_level_batch: chex.ArrayTree = struct.field(pytree_node=True)
119
+ mutation_last_level_batch: chex.ArrayTree = struct.field(pytree_node=True)
120
+
121
+ dr_last_rollout_batch: chex.ArrayTree = struct.field(pytree_node=True)
122
+ replay_last_rollout_batch: chex.ArrayTree = struct.field(pytree_node=True)
123
+ mutation_last_rollout_batch: chex.ArrayTree = struct.field(pytree_node=True)
124
+
125
+
126
+ # region PPO helper functions
127
+
128
+ # endregion
129
+
130
+
131
+ def train_state_to_log_dict(train_state: TrainState, level_sampler: LevelSampler) -> dict:
132
+ """To prevent the entire (large) train_state to be copied to the CPU when doing logging, this function returns all of the important information in a dictionary format.
133
+
134
+ Anything in the `log` key will be logged to wandb.
135
+
136
+ Args:
137
+ train_state (TrainState):
138
+ level_sampler (LevelSampler):
139
+
140
+ Returns:
141
+ dict:
142
+ """
143
+ sampler = train_state.sampler
144
+ idx = jnp.arange(level_sampler.capacity) < sampler["size"]
145
+ s = jnp.maximum(idx.sum(), 1)
146
+ return {
147
+ "log": {
148
+ "level_sampler/size": sampler["size"],
149
+ "level_sampler/episode_count": sampler["episode_count"],
150
+ "level_sampler/max_score": sampler["scores"].max(),
151
+ "level_sampler/weighted_score": (sampler["scores"] * level_sampler.level_weights(sampler)).sum(),
152
+ "level_sampler/mean_score": (sampler["scores"] * idx).sum() / s,
153
+ },
154
+ "info": {
155
+ "num_dr_updates": train_state.num_dr_updates,
156
+ "num_replay_updates": train_state.num_replay_updates,
157
+ "num_mutation_updates": train_state.num_mutation_updates,
158
+ },
159
+ }
160
+
161
+
162
+ def compute_learnability(config, done, reward, info, num_envs):
163
+ num_agents = 1
164
+ BATCH_ACTORS = num_envs * num_agents
165
+
166
+ rollout_length = config["num_steps"] * config["outer_rollout_steps"]
167
+
168
+ @partial(jax.vmap, in_axes=(None, 1, 1, 1))
169
+ @partial(jax.jit, static_argnums=(0,))
170
+ def _calc_outcomes_by_agent(max_steps: int, dones, returns, info):
171
+ idxs = jnp.arange(max_steps)
172
+
173
+ @partial(jax.vmap, in_axes=(0, 0))
174
+ def __ep_outcomes(start_idx, end_idx):
175
+ mask = (idxs > start_idx) & (idxs <= end_idx) & (end_idx != max_steps)
176
+ r = jnp.sum(returns * mask)
177
+ goal_r = info["GoalR"]
178
+ success = jnp.sum(goal_r * mask)
179
+ collision = 0
180
+ timeo = 0
181
+ l = end_idx - start_idx
182
+ return r, success, collision, timeo, l
183
+
184
+ done_idxs = jnp.argwhere(dones, size=50, fill_value=max_steps).squeeze()
185
+ mask_done = jnp.where(done_idxs == max_steps, 0, 1)
186
+ ep_return, success, collision, timeo, length = __ep_outcomes(
187
+ jnp.concatenate([jnp.array([-1]), done_idxs[:-1]]), done_idxs
188
+ )
189
+
190
+ return {
191
+ "ep_return": ep_return.mean(where=mask_done),
192
+ "num_episodes": mask_done.sum(),
193
+ "num_success": success.sum(where=mask_done),
194
+ "success_rate": success.mean(where=mask_done),
195
+ "collision_rate": collision.mean(where=mask_done),
196
+ "timeout_rate": timeo.mean(where=mask_done),
197
+ "ep_len": length.mean(where=mask_done),
198
+ }
199
+
200
+ done_by_env = done.reshape((-1, num_agents, num_envs))
201
+ reward_by_env = reward.reshape((-1, num_agents, num_envs))
202
+ o = _calc_outcomes_by_agent(rollout_length, done, reward, info)
203
+ success_by_env = o["success_rate"].reshape((num_agents, num_envs))
204
+ learnability_by_env = (success_by_env * (1 - success_by_env)).sum(axis=0)
205
+
206
+ return (
207
+ learnability_by_env,
208
+ o["num_episodes"].reshape(num_agents, num_envs).sum(axis=0),
209
+ o["num_success"].reshape(num_agents, num_envs).T,
210
+ ) # so agents is at the end.
211
+
212
+
213
+ def compute_score(
214
+ config: dict, dones: chex.Array, values: chex.Array, max_returns: chex.Array, reward, info, advantages: chex.Array
215
+ ) -> chex.Array:
216
+ # Computes the score for each level
217
+ if config["score_function"] == "MaxMC":
218
+ return max_mc(dones, values, max_returns)
219
+ elif config["score_function"] == "pvl":
220
+ return positive_value_loss(dones, advantages)
221
+ elif config["score_function"] == "learnability":
222
+ learnability, num_episodes, num_success = compute_learnability(
223
+ config, dones, reward, info, config["num_train_envs"]
224
+ )
225
+ return learnability
226
+ else:
227
+ raise ValueError(f"Unknown score function: {config['score_function']}")
228
+
229
+
230
+ def compute_all_scores(
231
+ config: dict,
232
+ dones: chex.Array,
233
+ values: chex.Array,
234
+ max_returns: chex.Array,
235
+ reward,
236
+ info,
237
+ advantages: chex.Array,
238
+ return_success_rate=False,
239
+ ):
240
+ mc = max_mc(dones, values, max_returns)
241
+ pvl = positive_value_loss(dones, advantages)
242
+ learnability, num_episodes, num_success = compute_learnability(
243
+ config, dones, reward, info, config["num_train_envs"]
244
+ )
245
+ if config["score_function"] == "MaxMC":
246
+ main_score = mc
247
+ elif config["score_function"] == "pvl":
248
+ main_score = pvl
249
+ elif config["score_function"] == "learnability":
250
+ main_score = learnability
251
+ else:
252
+ raise ValueError(f"Unknown score function: {config['score_function']}")
253
+ if return_success_rate:
254
+ success_rate = num_success.squeeze(1) / jnp.maximum(num_episodes, 1)
255
+ return main_score, (mc, pvl, learnability, success_rate)
256
+ return main_score, (mc, pvl, learnability)
257
+
258
+
259
+ @hydra.main(version_base=None, config_path="../configs", config_name="plr")
260
+ def main(config=None):
261
+ my_name = "PLR"
262
+ config = OmegaConf.to_container(config)
263
+ if config["ued"]["replay_prob"] == 0.0:
264
+ my_name = "DR"
265
+ elif config["ued"]["use_accel"]:
266
+ my_name = "ACCEL"
267
+
268
+ time_start = time.time()
269
+ config = normalise_config(config, my_name)
270
+ env_params, static_env_params = generate_params_from_config(config)
271
+ config["env_params"] = to_state_dict(env_params)
272
+ config["static_env_params"] = to_state_dict(static_env_params)
273
+
274
+ run = init_wandb(config, my_name)
275
+ config = wandb.config
276
+ time_prev = time.time()
277
+
278
+ def log_eval(stats, train_state_info):
279
+ nonlocal time_prev
280
+ print(f"Logging update: {stats['update_count']}")
281
+ total_loss = jnp.mean(stats["losses"][0])
282
+ if jnp.isnan(total_loss):
283
+ print("NaN loss, skipping logging")
284
+ raise ValueError("NaN loss")
285
+
286
+ # generic stats
287
+ env_steps = int(
288
+ int(stats["update_count"]) * config["num_train_envs"] * config["num_steps"] * config["outer_rollout_steps"]
289
+ )
290
+ env_steps_delta = (
291
+ config["eval_freq"] * config["num_train_envs"] * config["num_steps"] * config["outer_rollout_steps"]
292
+ )
293
+ time_now = time.time()
294
+ log_dict = {
295
+ "timing/num_updates": stats["update_count"],
296
+ "timing/num_env_steps": env_steps,
297
+ "timing/sps": env_steps_delta / (time_now - time_prev),
298
+ "timing/sps_agg": env_steps / (time_now - time_start),
299
+ "loss/total_loss": jnp.mean(stats["losses"][0]),
300
+ "loss/value_loss": jnp.mean(stats["losses"][1][0]),
301
+ "loss/policy_loss": jnp.mean(stats["losses"][1][1]),
302
+ "loss/entropy_loss": jnp.mean(stats["losses"][1][2]),
303
+ }
304
+ time_prev = time_now
305
+
306
+ # evaluation performance
307
+
308
+ returns = stats["eval_returns"]
309
+ log_dict.update({"eval/mean_eval_return": returns.mean()})
310
+ log_dict.update({"eval/mean_eval_learnability": stats["eval_learn"].mean()})
311
+ log_dict.update({"eval/mean_eval_solve_rate": stats["eval_solves"].mean()})
312
+ log_dict.update({"eval/mean_eval_eplen": stats["eval_ep_lengths"].mean()})
313
+ for i in range(config["num_eval_levels"]):
314
+ log_dict[f"eval_avg_return/{config['eval_levels'][i]}"] = returns[i]
315
+ log_dict[f"eval_avg_learnability/{config['eval_levels'][i]}"] = stats["eval_learn"][i]
316
+ log_dict[f"eval_avg_solve_rate/{config['eval_levels'][i]}"] = stats["eval_solves"][i]
317
+ log_dict[f"eval_avg_episode_length/{config['eval_levels'][i]}"] = stats["eval_ep_lengths"][i]
318
+ log_dict[f"eval_get_max_eplen/{config['eval_levels'][i]}"] = stats["eval_get_max_eplen"][i]
319
+ log_dict[f"episode_return_bigger_than_negative/{config['eval_levels'][i]}"] = stats[
320
+ "episode_return_bigger_than_negative"
321
+ ][i]
322
+
323
+ def _aggregate_per_size(values, name):
324
+ to_return = {}
325
+ for group_name, indices in eval_group_indices.items():
326
+ to_return[f"{name}_{group_name}"] = values[indices].mean()
327
+ return to_return
328
+
329
+ log_dict.update(_aggregate_per_size(returns, "eval_aggregate/return"))
330
+ log_dict.update(_aggregate_per_size(stats["eval_solves"], "eval_aggregate/solve_rate"))
331
+
332
+ if config["EVAL_ON_SAMPLED"]:
333
+ log_dict.update({"eval/mean_eval_return_sampled": stats["eval_dr_returns"].mean()})
334
+ log_dict.update({"eval/mean_eval_solve_rate_sampled": stats["eval_dr_solve_rates"].mean()})
335
+ log_dict.update({"eval/mean_eval_eplen_sampled": stats["eval_dr_eplen"].mean()})
336
+
337
+ # level sampler
338
+ log_dict.update(train_state_info["log"])
339
+
340
+ # images
341
+ log_dict.update(
342
+ {
343
+ "images/highest_scoring_level": wandb.Image(
344
+ np.array(stats["highest_scoring_level"]), caption="Highest scoring level"
345
+ )
346
+ }
347
+ )
348
+ log_dict.update(
349
+ {
350
+ "images/highest_weighted_level": wandb.Image(
351
+ np.array(stats["highest_weighted_level"]), caption="Highest weighted level"
352
+ )
353
+ }
354
+ )
355
+
356
+ for s in ["dr", "replay", "mutation"]:
357
+ if train_state_info["info"][f"num_{s}_updates"] > 0:
358
+ log_dict.update(
359
+ {
360
+ f"images/{s}_levels": [
361
+ wandb.Image(np.array(image), caption=f"{score}")
362
+ for image, score in zip(stats[f"{s}_levels"], stats[f"{s}_scores"])
363
+ ]
364
+ }
365
+ )
366
+ if stats["log_videos"]:
367
+ # animations
368
+ rollout_ep = stats[f"{s}_ep_len"]
369
+ arr = np.array(stats[f"{s}_rollout"][:rollout_ep])
370
+ log_dict.update(
371
+ {
372
+ f"media/{s}_eval": wandb.Video(
373
+ arr.astype(np.uint8), fps=15, caption=f"{s.capitalize()} (len {rollout_ep})"
374
+ )
375
+ }
376
+ )
377
+ # * 255
378
+
379
+ # DR, Replay and Mutate Returns
380
+ dr_inds = (stats["update_state"] == UpdateState.DR).nonzero()[0]
381
+ rep_inds = (stats["update_state"] == UpdateState.REPLAY).nonzero()[0]
382
+ mut_inds = (stats["update_state"] == UpdateState.MUTATE).nonzero()[0]
383
+
384
+ for name, inds in [
385
+ ("DR", dr_inds),
386
+ ("REPLAY", rep_inds),
387
+ ("MUTATION", mut_inds),
388
+ ]:
389
+ if len(inds) > 0:
390
+ log_dict.update(
391
+ {
392
+ f"{name}/episode_return": stats["episode_return"][inds].mean(),
393
+ f"{name}/mean_eplen": stats["returned_episode_lengths"][inds].mean(),
394
+ f"{name}/mean_success": stats["returned_episode_solved"][inds].mean(),
395
+ f"{name}/noop_return": stats["noop_returns"][inds].mean(),
396
+ f"{name}/noop_eplen": stats["noop_eplen"][inds].mean(),
397
+ f"{name}/noop_success": stats["noop_success"][inds].mean(),
398
+ f"{name}/random_return": stats["random_returns"][inds].mean(),
399
+ f"{name}/random_eplen": stats["random_eplen"][inds].mean(),
400
+ f"{name}/random_success": stats["random_success"][inds].mean(),
401
+ }
402
+ )
403
+ for k in stats:
404
+ if "complexity/" in k:
405
+ k2 = "complexity/" + name + "_" + k.replace("complexity/", "")
406
+ log_dict.update({k2: stats[k][inds].mean()})
407
+ if "ued_scores/" in k:
408
+ k2 = "ued_scores/" + name + "_" + k.replace("ued_scores/", "")
409
+ log_dict.update({k2: stats[k][inds].mean()})
410
+
411
+ # Eval rollout animations
412
+ if stats["log_videos"]:
413
+ for i in range((config["num_eval_levels"])):
414
+ frames, episode_length = stats["eval_animation"][0][:, i], stats["eval_animation"][1][i]
415
+ frames = np.array(frames[:episode_length])
416
+ log_dict.update(
417
+ {
418
+ f"media/eval_video_{config['eval_levels'][i]}": wandb.Video(
419
+ frames.astype(np.uint8), fps=15, caption=f"Len ({episode_length})"
420
+ )
421
+ }
422
+ )
423
+
424
+ wandb.log(log_dict)
425
+
426
+ def get_all_metrics(
427
+ rng,
428
+ losses,
429
+ info,
430
+ init_env_state,
431
+ init_obs,
432
+ dones,
433
+ grads,
434
+ all_ued_scores,
435
+ new_levels,
436
+ ):
437
+ noop_returns, noop_len, noop_success, random_returns, random_lens, random_success = no_op_and_random_rollout(
438
+ env,
439
+ env_params,
440
+ rng,
441
+ init_obs,
442
+ init_env_state,
443
+ config["num_train_envs"],
444
+ config["num_steps"] * config["outer_rollout_steps"],
445
+ )
446
+ metrics = (
447
+ {
448
+ "losses": jax.tree_util.tree_map(lambda x: x.mean(), losses),
449
+ "returned_episode_lengths": (info["returned_episode_lengths"] * dones).sum()
450
+ / jnp.maximum(1, dones.sum()),
451
+ "max_episode_length": info["returned_episode_lengths"].max(),
452
+ "levels_played": init_env_state.env_state.env_state,
453
+ "episode_return": (info["returned_episode_returns"] * dones).sum() / jnp.maximum(1, dones.sum()),
454
+ "episode_return_v2": (info["returned_episode_returns"] * info["returned_episode"]).sum()
455
+ / jnp.maximum(1, info["returned_episode"].sum()),
456
+ "grad_norms": grads.mean(),
457
+ "noop_returns": noop_returns,
458
+ "noop_eplen": noop_len,
459
+ "noop_success": noop_success,
460
+ "random_returns": random_returns,
461
+ "random_eplen": random_lens,
462
+ "random_success": random_success,
463
+ "returned_episode_solved": (info["returned_episode_solved"] * dones).sum()
464
+ / jnp.maximum(1, dones.sum()),
465
+ }
466
+ | get_level_complexity_metrics(new_levels, static_env_params)
467
+ | get_ued_score_metrics(all_ued_scores)
468
+ )
469
+ return metrics
470
+
471
+ # Setup the environment.
472
+ def make_env(static_env_params):
473
+ env = make_kinetix_env_from_name(config["env_name"], static_env_params=static_env_params)
474
+ env = AutoReplayWrapper(env)
475
+ env = UnderspecifiedToGymnaxWrapper(env)
476
+ env = DenseRewardWrapper(env, dense_reward_scale=config["dense_reward_scale"])
477
+ env = LogWrapper(env)
478
+ return env
479
+
480
+ env = make_env(static_env_params)
481
+
482
+ if config["train_level_mode"] == "list":
483
+ sample_random_level = make_reset_train_function_with_list_of_levels(
484
+ config, config["train_levels_list"], static_env_params, make_pcg_state=False, is_loading_train_levels=True
485
+ )
486
+ elif config["train_level_mode"] == "random":
487
+ sample_random_level = make_reset_train_function_with_mutations(
488
+ env.physics_engine, env_params, static_env_params, config, make_pcg_state=False
489
+ )
490
+ else:
491
+ raise ValueError(f"Unknown train_level_mode: {config['train_level_mode']}")
492
+
493
+ if config["use_accel"] and config["accel_start_from_empty"]:
494
+
495
+ def make_sample_random_level():
496
+ def inner(rng):
497
+ def _inner_accel(rng):
498
+ return create_random_starting_distribution(
499
+ rng, env_params, static_env_params, ued_params, config["env_size_name"], controllable=True
500
+ )
501
+
502
+ def _inner_accel_not_controllable(rng):
503
+ return create_random_starting_distribution(
504
+ rng, env_params, static_env_params, ued_params, config["env_size_name"], controllable=False
505
+ )
506
+
507
+ rng, _rng = jax.random.split(rng)
508
+ return _inner_accel(_rng)
509
+
510
+ return inner
511
+
512
+ sample_random_level = make_sample_random_level()
513
+
514
+ sample_random_levels = make_vmapped_filtered_level_sampler(
515
+ sample_random_level, env_params, static_env_params, config, make_pcg_state=False, env=env
516
+ )
517
+
518
+ def generate_world():
519
+ raise NotImplementedError
520
+ pass
521
+
522
+ def generate_eval_world(rng, env_params, static_env_params, level_idx):
523
+ # jax.random.split(jax.random.PRNGKey(101), num_levels), env_params, static_env_params, jnp.arange(num_levels)
524
+
525
+ raise NotImplementedError
526
+
527
+ _, eval_static_env_params = generate_params_from_config(
528
+ config["eval_env_size_true"] | {"frame_skip": config["frame_skip"]}
529
+ )
530
+ eval_env = make_env(eval_static_env_params)
531
+ ued_params = generate_ued_params_from_config(config)
532
+
533
+ mutate_world = make_mutate_env(static_env_params, env_params, ued_params)
534
+
535
+ def make_render_fn(static_env_params):
536
+ render_fn_inner = make_render_pixels(env_params, static_env_params)
537
+ render_fn = lambda x: render_fn_inner(x).transpose(1, 0, 2)[::-1]
538
+ return render_fn
539
+
540
+ render_fn = make_render_fn(static_env_params)
541
+ render_fn_eval = make_render_fn(eval_static_env_params)
542
+ if config["EVAL_ON_SAMPLED"]:
543
+ NUM_EVAL_DR_LEVELS = 200
544
+ key_to_sample_dr_eval_set = jax.random.PRNGKey(100)
545
+ DR_EVAL_LEVELS = sample_random_levels(key_to_sample_dr_eval_set, NUM_EVAL_DR_LEVELS)
546
+
547
+ # And the level sampler
548
+ level_sampler = LevelSampler(
549
+ capacity=config["level_buffer_capacity"],
550
+ replay_prob=config["replay_prob"],
551
+ staleness_coeff=config["staleness_coeff"],
552
+ minimum_fill_ratio=config["minimum_fill_ratio"],
553
+ prioritization=config["prioritization"],
554
+ prioritization_params={"temperature": config["temperature"], "k": config["topk_k"]},
555
+ duplicate_check=config["buffer_duplicate_check"],
556
+ )
557
+
558
+ @jax.jit
559
+ def create_train_state(rng) -> TrainState:
560
+ # Creates the train state
561
+ def linear_schedule(count):
562
+ frac = 1.0 - (count // (config["num_minibatches"] * config["update_epochs"])) / (
563
+ config["num_updates"] * config["outer_rollout_steps"]
564
+ )
565
+ return config["lr"] * frac
566
+
567
+ rng, _rng = jax.random.split(rng)
568
+ init_state = jax.tree.map(lambda x: x[0], sample_random_levels(_rng, 1))
569
+
570
+ rng, _rng = jax.random.split(rng)
571
+ obs, _ = env.reset_to_level(_rng, init_state, env_params)
572
+ ns = config["num_steps"] * config["outer_rollout_steps"]
573
+ obs = jax.tree.map(
574
+ lambda x: jnp.repeat(jnp.repeat(x[None, ...], config["num_train_envs"], axis=0)[None, ...], ns, axis=0),
575
+ obs,
576
+ )
577
+ init_x = (obs, jnp.zeros((ns, config["num_train_envs"]), dtype=jnp.bool_))
578
+ network = make_network_from_config(env, env_params, config)
579
+ rng, _rng = jax.random.split(rng)
580
+ network_params = network.init(_rng, ScannedRNN.initialize_carry(config["num_train_envs"]), init_x)
581
+
582
+ if config["anneal_lr"]:
583
+ tx = optax.chain(
584
+ optax.clip_by_global_norm(config["max_grad_norm"]),
585
+ optax.adam(learning_rate=linear_schedule, eps=1e-5),
586
+ )
587
+ else:
588
+ tx = optax.chain(
589
+ optax.clip_by_global_norm(config["max_grad_norm"]),
590
+ optax.adam(config["lr"], eps=1e-5),
591
+ )
592
+
593
+ pholder_level = jax.tree.map(lambda x: x[0], sample_random_levels(jax.random.PRNGKey(0), 1))
594
+ sampler = level_sampler.initialize(pholder_level, {"max_return": -jnp.inf})
595
+ pholder_level_batch = jax.tree_util.tree_map(
596
+ lambda x: jnp.array([x]).repeat(config["num_train_envs"], axis=0), pholder_level
597
+ )
598
+ pholder_rollout_batch = (
599
+ jax.tree.map(
600
+ lambda x: jnp.repeat(
601
+ jnp.expand_dims(x, 0), repeats=config["num_steps"] * config["outer_rollout_steps"], axis=0
602
+ ),
603
+ init_state,
604
+ ),
605
+ init_x[1][:, 0],
606
+ )
607
+
608
+ pholder_level_batch_scores = jnp.zeros((config["num_train_envs"],), dtype=jnp.float32)
609
+ train_state = TrainState.create(
610
+ apply_fn=network.apply,
611
+ params=network_params,
612
+ tx=tx,
613
+ sampler=sampler,
614
+ update_state=0,
615
+ num_dr_updates=0,
616
+ num_replay_updates=0,
617
+ num_mutation_updates=0,
618
+ dr_last_level_batch_scores=pholder_level_batch_scores,
619
+ replay_last_level_batch_scores=pholder_level_batch_scores,
620
+ mutation_last_level_batch_scores=pholder_level_batch_scores,
621
+ dr_last_level_batch=pholder_level_batch,
622
+ replay_last_level_batch=pholder_level_batch,
623
+ mutation_last_level_batch=pholder_level_batch,
624
+ dr_last_rollout_batch=pholder_rollout_batch,
625
+ replay_last_rollout_batch=pholder_rollout_batch,
626
+ mutation_last_rollout_batch=pholder_rollout_batch,
627
+ )
628
+
629
+ if config["load_from_checkpoint"] != None:
630
+ print("LOADING from", config["load_from_checkpoint"], "with only params =", config["load_only_params"])
631
+ train_state = load_train_state_from_wandb_artifact_path(
632
+ train_state,
633
+ config["load_from_checkpoint"],
634
+ load_only_params=config["load_only_params"],
635
+ legacy=config["load_legacy_checkpoint"],
636
+ )
637
+ return train_state
638
+
639
+ all_eval_levels = get_eval_levels(config["eval_levels"], eval_env.static_env_params)
640
+ eval_group_indices = get_eval_level_groups(config["eval_levels"])
641
+
642
+ @jax.jit
643
+ def train_step(carry: Tuple[chex.PRNGKey, TrainState], _):
644
+ """
645
+ This is the main training loop. It basically calls either `on_new_levels`, `on_replay_levels`, or `on_mutate_levels` at every step.
646
+ """
647
+
648
+ def on_new_levels(rng: chex.PRNGKey, train_state: TrainState):
649
+ """
650
+ Samples new (randomly-generated) levels and evaluates the policy on these. It also then adds the levels to the level buffer if they have high-enough scores.
651
+ The agent is updated on these trajectories iff `config["exploratory_grad_updates"]` is True.
652
+ """
653
+ sampler = train_state.sampler
654
+
655
+ # Reset
656
+ rng, rng_levels, rng_reset = jax.random.split(rng, 3)
657
+ new_levels = sample_random_levels(rng_levels, config["num_train_envs"])
658
+ init_obs, init_env_state = jax.vmap(env.reset_to_level, in_axes=(0, 0, None))(
659
+ jax.random.split(rng_reset, config["num_train_envs"]), new_levels, env_params
660
+ )
661
+ init_hstate = ScannedRNN.initialize_carry(config["num_train_envs"])
662
+ # Rollout
663
+ (
664
+ (rng, train_state, new_hstate, last_obs, last_env_state),
665
+ (
666
+ obs,
667
+ actions,
668
+ rewards,
669
+ dones,
670
+ log_probs,
671
+ values,
672
+ info,
673
+ advantages,
674
+ targets,
675
+ losses,
676
+ grads,
677
+ rollout_states,
678
+ ),
679
+ ) = sample_trajectories_and_learn(
680
+ env,
681
+ env_params,
682
+ config,
683
+ rng,
684
+ train_state,
685
+ init_hstate,
686
+ init_obs,
687
+ init_env_state,
688
+ update_grad=config["exploratory_grad_updates"],
689
+ return_states=True,
690
+ )
691
+ max_returns = compute_max_returns(dones, rewards)
692
+ scores, all_ued_scores = compute_all_scores(config, dones, values, max_returns, rewards, info, advantages)
693
+ sampler, _ = level_sampler.insert_batch(sampler, new_levels, scores, {"max_return": max_returns})
694
+ rng, _rng = jax.random.split(rng)
695
+ metrics = {
696
+ "update_state": UpdateState.DR,
697
+ } | get_all_metrics(_rng, losses, info, init_env_state, init_obs, dones, grads, all_ued_scores, new_levels)
698
+
699
+ train_state = train_state.replace(
700
+ sampler=sampler,
701
+ update_state=UpdateState.DR,
702
+ num_dr_updates=train_state.num_dr_updates + 1,
703
+ dr_last_level_batch=new_levels,
704
+ dr_last_level_batch_scores=scores,
705
+ dr_last_rollout_batch=jax.tree.map(
706
+ lambda x: x[:, 0], (rollout_states.env_state.env_state.env_state, dones)
707
+ ),
708
+ )
709
+ return (rng, train_state), metrics
710
+
711
+ def on_replay_levels(rng: chex.PRNGKey, train_state: TrainState):
712
+ """
713
+ This samples levels from the level buffer, and updates the policy on them.
714
+ """
715
+ sampler = train_state.sampler
716
+
717
+ # Collect trajectories on replay levels
718
+ rng, rng_levels, rng_reset = jax.random.split(rng, 3)
719
+ sampler, (level_inds, levels) = level_sampler.sample_replay_levels(
720
+ sampler, rng_levels, config["num_train_envs"]
721
+ )
722
+ init_obs, init_env_state = jax.vmap(env.reset_to_level, in_axes=(0, 0, None))(
723
+ jax.random.split(rng_reset, config["num_train_envs"]), levels, env_params
724
+ )
725
+ init_hstate = ScannedRNN.initialize_carry(config["num_train_envs"])
726
+ (
727
+ (rng, train_state, new_hstate, last_obs, last_env_state),
728
+ (
729
+ obs,
730
+ actions,
731
+ rewards,
732
+ dones,
733
+ log_probs,
734
+ values,
735
+ info,
736
+ advantages,
737
+ targets,
738
+ losses,
739
+ grads,
740
+ rollout_states,
741
+ ),
742
+ ) = sample_trajectories_and_learn(
743
+ env,
744
+ env_params,
745
+ config,
746
+ rng,
747
+ train_state,
748
+ init_hstate,
749
+ init_obs,
750
+ init_env_state,
751
+ update_grad=True,
752
+ return_states=True,
753
+ )
754
+
755
+ max_returns = jnp.maximum(
756
+ level_sampler.get_levels_extra(sampler, level_inds)["max_return"], compute_max_returns(dones, rewards)
757
+ )
758
+ scores, all_ued_scores = compute_all_scores(config, dones, values, max_returns, rewards, info, advantages)
759
+ sampler = level_sampler.update_batch(sampler, level_inds, scores, {"max_return": max_returns})
760
+
761
+ rng, _rng = jax.random.split(rng)
762
+ metrics = {
763
+ "update_state": UpdateState.REPLAY,
764
+ } | get_all_metrics(_rng, losses, info, init_env_state, init_obs, dones, grads, all_ued_scores, levels)
765
+ train_state = train_state.replace(
766
+ sampler=sampler,
767
+ update_state=UpdateState.REPLAY,
768
+ num_replay_updates=train_state.num_replay_updates + 1,
769
+ replay_last_level_batch=levels,
770
+ replay_last_level_batch_scores=scores,
771
+ replay_last_rollout_batch=jax.tree.map(
772
+ lambda x: x[:, 0], (rollout_states.env_state.env_state.env_state, dones)
773
+ ),
774
+ )
775
+ return (rng, train_state), metrics
776
+
777
+ def on_mutate_levels(rng: chex.PRNGKey, train_state: TrainState):
778
+ """
779
+ This mutates the previous batch of replay levels and potentially adds them to the level buffer.
780
+ This also updates the policy iff `config["exploratory_grad_updates"]` is True.
781
+ """
782
+
783
+ sampler = train_state.sampler
784
+ rng, rng_mutate, rng_reset = jax.random.split(rng, 3)
785
+
786
+ # mutate
787
+ parent_levels = train_state.replay_last_level_batch
788
+ child_levels = jax.vmap(mutate_world, (0, 0, None))(
789
+ jax.random.split(rng_mutate, config["num_train_envs"]), parent_levels, config["num_edits"]
790
+ )
791
+ init_obs, init_env_state = jax.vmap(env.reset_to_level, in_axes=(0, 0, None))(
792
+ jax.random.split(rng_reset, config["num_train_envs"]), child_levels, env_params
793
+ )
794
+
795
+ init_hstate = ScannedRNN.initialize_carry(config["num_train_envs"])
796
+ # rollout
797
+ (
798
+ (rng, train_state, new_hstate, last_obs, last_env_state),
799
+ (
800
+ obs,
801
+ actions,
802
+ rewards,
803
+ dones,
804
+ log_probs,
805
+ values,
806
+ info,
807
+ advantages,
808
+ targets,
809
+ losses,
810
+ grads,
811
+ rollout_states,
812
+ ),
813
+ ) = sample_trajectories_and_learn(
814
+ env,
815
+ env_params,
816
+ config,
817
+ rng,
818
+ train_state,
819
+ init_hstate,
820
+ init_obs,
821
+ init_env_state,
822
+ update_grad=config["exploratory_grad_updates"],
823
+ return_states=True,
824
+ )
825
+
826
+ max_returns = compute_max_returns(dones, rewards)
827
+ scores, all_ued_scores = compute_all_scores(config, dones, values, max_returns, rewards, info, advantages)
828
+ sampler, _ = level_sampler.insert_batch(sampler, child_levels, scores, {"max_return": max_returns})
829
+
830
+ rng, _rng = jax.random.split(rng)
831
+ metrics = {"update_state": UpdateState.MUTATE,} | get_all_metrics(
832
+ _rng, losses, info, init_env_state, init_obs, dones, grads, all_ued_scores, child_levels
833
+ )
834
+
835
+ train_state = train_state.replace(
836
+ sampler=sampler,
837
+ update_state=UpdateState.DR,
838
+ num_mutation_updates=train_state.num_mutation_updates + 1,
839
+ mutation_last_level_batch=child_levels,
840
+ mutation_last_level_batch_scores=scores,
841
+ mutation_last_rollout_batch=jax.tree.map(
842
+ lambda x: x[:, 0], (rollout_states.env_state.env_state.env_state, dones)
843
+ ),
844
+ )
845
+ return (rng, train_state), metrics
846
+
847
+ rng, train_state = carry
848
+ rng, rng_replay = jax.random.split(rng)
849
+
850
+ # The train step makes a decision on which branch to take, either on_new, on_replay or on_mutate.
851
+ # on_mutate is only called if the replay branch has been taken before (as it uses `train_state.update_state`).
852
+ branches = [
853
+ on_new_levels,
854
+ on_replay_levels,
855
+ ]
856
+ if config["use_accel"]:
857
+ s = train_state.update_state
858
+ branch = (1 - s) * level_sampler.sample_replay_decision(train_state.sampler, rng_replay) + 2 * s
859
+ branches.append(on_mutate_levels)
860
+ else:
861
+ branch = level_sampler.sample_replay_decision(train_state.sampler, rng_replay).astype(int)
862
+
863
+ return jax.lax.switch(branch, branches, rng, train_state)
864
+
865
+ @partial(jax.jit, static_argnums=(2,))
866
+ def eval(rng: chex.PRNGKey, train_state: TrainState, keep_states=True):
867
+ """
868
+ This evaluates the current policy on the set of evaluation levels specified by config["eval_levels"].
869
+ It returns (states, cum_rewards, episode_lengths), with shapes (num_steps, num_eval_levels, ...), (num_eval_levels,), (num_eval_levels,)
870
+ """
871
+ num_levels = config["num_eval_levels"]
872
+ return general_eval(
873
+ rng,
874
+ eval_env,
875
+ env_params,
876
+ train_state,
877
+ all_eval_levels,
878
+ env_params.max_timesteps,
879
+ num_levels,
880
+ keep_states=keep_states,
881
+ return_trajectories=True,
882
+ )
883
+
884
+ @partial(jax.jit, static_argnums=(2,))
885
+ def eval_on_dr_levels(rng: chex.PRNGKey, train_state: TrainState, keep_states=False):
886
+ return general_eval(
887
+ rng,
888
+ env,
889
+ env_params,
890
+ train_state,
891
+ DR_EVAL_LEVELS,
892
+ env_params.max_timesteps,
893
+ NUM_EVAL_DR_LEVELS,
894
+ keep_states=keep_states,
895
+ )
896
+
897
+ @jax.jit
898
+ def train_and_eval_step(runner_state, _):
899
+ """
900
+ This function runs the train_step for a certain number of iterations, and then evaluates the policy.
901
+ It returns the updated train state, and a dictionary of metrics.
902
+ """
903
+ # Train
904
+ (rng, train_state), metrics = jax.lax.scan(train_step, runner_state, None, config["eval_freq"])
905
+
906
+ # Eval
907
+ metrics["update_count"] = (
908
+ train_state.num_dr_updates + train_state.num_replay_updates + train_state.num_mutation_updates
909
+ )
910
+
911
+ vid_frequency = get_video_frequency(config, metrics["update_count"])
912
+ should_log_videos = metrics["update_count"] % vid_frequency == 0
913
+
914
+ def _compute_eval_learnability(dones, rewards, infos):
915
+ @jax.vmap
916
+ def _single(d, r, i):
917
+ learn, num_eps, num_succ = compute_learnability(config, d, r, i, config["num_eval_levels"])
918
+
919
+ return num_eps, num_succ.squeeze(-1)
920
+
921
+ num_eps, num_succ = _single(dones, rewards, infos)
922
+ num_eps, num_succ = num_eps.sum(axis=0), num_succ.sum(axis=0)
923
+ success_rate = num_succ / jnp.maximum(1, num_eps)
924
+
925
+ return success_rate * (1 - success_rate)
926
+
927
+ @jax.jit
928
+ def _get_eval(rng):
929
+ metrics = {}
930
+ rng, rng_eval = jax.random.split(rng)
931
+ (states, cum_rewards, done_idx, episode_lengths, eval_infos), (eval_dones, eval_rewards) = jax.vmap(
932
+ eval, (0, None)
933
+ )(jax.random.split(rng_eval, config["eval_num_attempts"]), train_state)
934
+
935
+ # learnability here of the holdout set:
936
+ eval_learn = _compute_eval_learnability(eval_dones, eval_rewards, eval_infos)
937
+ # Collect Metrics
938
+ eval_returns = cum_rewards.mean(axis=0) # (num_eval_levels,)
939
+ eval_solves = (eval_infos["returned_episode_solved"] * eval_dones).sum(axis=1) / jnp.maximum(
940
+ 1, eval_dones.sum(axis=1)
941
+ )
942
+ eval_solves = eval_solves.mean(axis=0)
943
+ metrics["eval_returns"] = eval_returns
944
+ metrics["eval_ep_lengths"] = episode_lengths.mean(axis=0)
945
+ metrics["eval_learn"] = eval_learn
946
+ metrics["eval_solves"] = eval_solves
947
+
948
+ metrics["eval_get_max_eplen"] = (episode_lengths == env_params.max_timesteps).mean(axis=0)
949
+ metrics["episode_return_bigger_than_negative"] = (cum_rewards > -0.4).mean(axis=0)
950
+
951
+ if config["EVAL_ON_SAMPLED"]:
952
+ states_dr, cum_rewards_dr, done_idx_dr, episode_lengths_dr, infos_dr = jax.vmap(
953
+ eval_on_dr_levels, (0, None)
954
+ )(jax.random.split(rng_eval, config["eval_num_attempts"]), train_state)
955
+
956
+ eval_dr_returns = cum_rewards_dr.mean(axis=0).mean()
957
+ eval_dr_eplen = episode_lengths_dr.mean(axis=0).mean()
958
+
959
+ my_eval_dones = infos_dr["returned_episode"]
960
+ eval_dr_solves = (infos_dr["returned_episode_solved"] * my_eval_dones).sum(axis=1) / jnp.maximum(
961
+ 1, my_eval_dones.sum(axis=1)
962
+ )
963
+
964
+ metrics["eval_dr_returns"] = eval_dr_returns
965
+ metrics["eval_dr_eplen"] = eval_dr_eplen
966
+ metrics["eval_dr_solve_rates"] = eval_dr_solves
967
+ return metrics, states, episode_lengths, cum_rewards
968
+
969
+ @jax.jit
970
+ def _get_videos(rng, states, episode_lengths, cum_rewards):
971
+ metrics = {"log_videos": True}
972
+
973
+ # just grab the first run
974
+ states, episode_lengths = jax.tree_util.tree_map(
975
+ lambda x: x[0], (states, episode_lengths)
976
+ ) # (num_steps, num_eval_levels, ...), (num_eval_levels,)
977
+ # And one attempt
978
+ states = jax.tree_util.tree_map(lambda x: x[:, :], states)
979
+ episode_lengths = episode_lengths[:]
980
+ images = jax.vmap(jax.vmap(render_fn_eval))(
981
+ states.env_state.env_state.env_state
982
+ ) # (num_steps, num_eval_levels, ...)
983
+ frames = images.transpose(
984
+ 0, 1, 4, 2, 3
985
+ ) # WandB expects color channel before image dimensions when dealing with animations for some reason
986
+
987
+ @jax.jit
988
+ def _get_video(rollout_batch):
989
+ states = rollout_batch[0]
990
+ images = jax.vmap(render_fn)(states) # dimensions are (steps, x, y, 3)
991
+ return (
992
+ # jax.tree.map(lambda x: x[:].transpose(0, 2, 1, 3)[:, ::-1], images).transpose(0, 3, 1, 2),
993
+ images.transpose(0, 3, 1, 2),
994
+ # images.transpose(0, 1, 4, 2, 3),
995
+ rollout_batch[1][:].argmax(),
996
+ )
997
+
998
+ # rollouts
999
+ metrics["dr_rollout"], metrics["dr_ep_len"] = _get_video(train_state.dr_last_rollout_batch)
1000
+ metrics["replay_rollout"], metrics["replay_ep_len"] = _get_video(train_state.replay_last_rollout_batch)
1001
+ metrics["mutation_rollout"], metrics["mutation_ep_len"] = _get_video(
1002
+ train_state.mutation_last_rollout_batch
1003
+ )
1004
+
1005
+ metrics["eval_animation"] = (frames, episode_lengths)
1006
+
1007
+ metrics["eval_returns_video"] = cum_rewards[0]
1008
+ metrics["eval_len_video"] = episode_lengths
1009
+
1010
+ # Eval on sampled
1011
+
1012
+ return metrics
1013
+
1014
+ @jax.jit
1015
+ def _get_dummy_videos(rng, states, episode_lengths, cum_rewards):
1016
+ n_eval = config["num_eval_levels"]
1017
+ nsteps = env_params.max_timesteps
1018
+ nsteps2 = config["outer_rollout_steps"] * config["num_steps"]
1019
+ img_size = (
1020
+ env.static_env_params.screen_dim[0] // env.static_env_params.downscale,
1021
+ env.static_env_params.screen_dim[1] // env.static_env_params.downscale,
1022
+ )
1023
+ return {
1024
+ "log_videos": False,
1025
+ "dr_rollout": jnp.zeros((nsteps2, 3, *img_size), jnp.float32),
1026
+ "dr_ep_len": jnp.zeros((), jnp.int32),
1027
+ "replay_rollout": jnp.zeros((nsteps2, 3, *img_size), jnp.float32),
1028
+ "replay_ep_len": jnp.zeros((), jnp.int32),
1029
+ "mutation_rollout": jnp.zeros((nsteps2, 3, *img_size), jnp.float32),
1030
+ "mutation_ep_len": jnp.zeros((), jnp.int32),
1031
+ # "eval_returns": jnp.zeros((n_eval,), jnp.float32),
1032
+ # "eval_solves": jnp.zeros((n_eval,), jnp.float32),
1033
+ # "eval_learn": jnp.zeros((n_eval,), jnp.float32),
1034
+ # "eval_ep_lengths": jnp.zeros((n_eval,), jnp.int32),
1035
+ "eval_animation": (
1036
+ jnp.zeros((nsteps, n_eval, 3, *img_size), jnp.float32),
1037
+ jnp.zeros((n_eval,), jnp.int32),
1038
+ ),
1039
+ "eval_returns_video": jnp.zeros((n_eval,), jnp.float32),
1040
+ "eval_len_video": jnp.zeros((n_eval,), jnp.int32),
1041
+ }
1042
+
1043
+ rng, rng_eval, rng_vid = jax.random.split(rng, 3)
1044
+
1045
+ metrics_eval, states, episode_lengths, cum_rewards = _get_eval(rng_eval)
1046
+ metrics = {
1047
+ **metrics,
1048
+ **metrics_eval,
1049
+ **jax.lax.cond(
1050
+ should_log_videos, _get_videos, _get_dummy_videos, rng_vid, states, episode_lengths, cum_rewards
1051
+ ),
1052
+ }
1053
+ max_num_images = 8
1054
+
1055
+ top_regret_ones = max_num_images // 2
1056
+ bot_regret_ones = max_num_images - top_regret_ones
1057
+
1058
+ @jax.jit
1059
+ def get_values(level_batch, scores):
1060
+ args = jnp.argsort(scores) # low scores are at the start, high scores are at the end
1061
+
1062
+ low_scores = args[:bot_regret_ones]
1063
+ high_scores = args[-top_regret_ones:]
1064
+
1065
+ low_levels = jax.tree.map(lambda x: x[low_scores], level_batch)
1066
+ high_levels = jax.tree.map(lambda x: x[high_scores], level_batch)
1067
+
1068
+ low_scores = scores[low_scores]
1069
+ high_scores = scores[high_scores]
1070
+ # now concatenate:
1071
+ return jax.vmap(render_fn)(
1072
+ jax.tree.map(lambda x, y: jnp.concatenate([x, y], axis=0), low_levels, high_levels)
1073
+ ), jnp.concatenate([low_scores, high_scores], axis=0)
1074
+
1075
+ metrics["dr_levels"], metrics["dr_scores"] = get_values(
1076
+ train_state.dr_last_level_batch, train_state.dr_last_level_batch_scores
1077
+ )
1078
+ metrics["replay_levels"], metrics["replay_scores"] = get_values(
1079
+ train_state.replay_last_level_batch, train_state.replay_last_level_batch_scores
1080
+ )
1081
+ metrics["mutation_levels"], metrics["mutation_scores"] = get_values(
1082
+ train_state.mutation_last_level_batch, train_state.mutation_last_level_batch_scores
1083
+ )
1084
+
1085
+ def _t(i):
1086
+ return jax.lax.select(i == 0, config["num_steps"], i)
1087
+
1088
+ metrics["dr_ep_len"] = _t(train_state.dr_last_rollout_batch[1][:].argmax())
1089
+ metrics["replay_ep_len"] = _t(train_state.replay_last_rollout_batch[1][:].argmax())
1090
+ metrics["mutation_ep_len"] = _t(train_state.mutation_last_rollout_batch[1][:].argmax())
1091
+
1092
+ highest_scoring_level = level_sampler.get_levels(train_state.sampler, train_state.sampler["scores"].argmax())
1093
+ highest_weighted_level = level_sampler.get_levels(
1094
+ train_state.sampler, level_sampler.level_weights(train_state.sampler).argmax()
1095
+ )
1096
+
1097
+ metrics["highest_scoring_level"] = render_fn(highest_scoring_level)
1098
+ metrics["highest_weighted_level"] = render_fn(highest_weighted_level)
1099
+
1100
+ # log_eval(metrics, train_state_to_log_dict(runner_state[1], level_sampler))
1101
+ jax.debug.callback(log_eval, metrics, train_state_to_log_dict(runner_state[1], level_sampler))
1102
+ return (rng, train_state), {"update_count": metrics["update_count"]}
1103
+
1104
+ def log_checkpoint(update_count, train_state):
1105
+ if config["save_path"] is not None and config["checkpoint_save_freq"] > 1:
1106
+ steps = (
1107
+ int(update_count)
1108
+ * int(config["num_train_envs"])
1109
+ * int(config["num_steps"])
1110
+ * int(config["outer_rollout_steps"])
1111
+ )
1112
+ # save_params_to_wandb(train_state.params, steps, config)
1113
+ save_model_to_wandb(train_state, steps, config)
1114
+
1115
+ def train_eval_and_checkpoint_step(runner_state, _):
1116
+ runner_state, metrics = jax.lax.scan(
1117
+ train_and_eval_step, runner_state, xs=jnp.arange(config["checkpoint_save_freq"] // config["eval_freq"])
1118
+ )
1119
+ jax.debug.callback(log_checkpoint, metrics["update_count"][-1], runner_state[1])
1120
+ return runner_state, metrics
1121
+
1122
+ # Set up the train states
1123
+ rng = jax.random.PRNGKey(config["seed"])
1124
+ rng_init, rng_train = jax.random.split(rng)
1125
+
1126
+ train_state = create_train_state(rng_init)
1127
+ runner_state = (rng_train, train_state)
1128
+
1129
+ runner_state, metrics = jax.lax.scan(
1130
+ train_eval_and_checkpoint_step,
1131
+ runner_state,
1132
+ xs=jnp.arange((config["num_updates"]) // (config["checkpoint_save_freq"])),
1133
+ )
1134
+
1135
+ if config["save_path"] is not None:
1136
+ # save_params_to_wandb(runner_state[1].params, config["total_timesteps"], config)
1137
+ save_model_to_wandb(runner_state[1], config["total_timesteps"], config, is_final=True)
1138
+
1139
+ return runner_state[1]
1140
+
1141
+
1142
+ if __name__ == "__main__":
1143
+ main()
Kinetix/experiments/ppo.py ADDED
@@ -0,0 +1,468 @@
1
+ import os
2
+ import hydra
3
+ from omegaconf import OmegaConf
4
+
5
+ from kinetix.environment.ued.ued import (
6
+ make_reset_train_function_with_list_of_levels,
7
+ make_reset_train_function_with_mutations,
8
+ )
9
+ from kinetix.render.renderer_pixels import make_render_pixels
10
+ from kinetix.util.config import (
11
+ get_video_frequency,
12
+ init_wandb,
13
+ normalise_config,
14
+ generate_params_from_config,
15
+ )
16
+
17
+ os.environ["WANDB_DISABLE_SERVICE"] = "True"
18
+
19
+
20
+ import sys
21
+ from typing import Any, NamedTuple
22
+
23
+ import jax
24
+ import jax.numpy as jnp
25
+ import numpy as np
26
+ import optax
27
+ from flax.training.train_state import TrainState
28
+
29
+ from kinetix.models import make_network_from_config
30
+ from kinetix.util.learning import general_eval, get_eval_levels
31
+ from flax.serialization import to_state_dict
32
+
33
+ import wandb
34
+ from kinetix.environment.env import PixelObservations, make_kinetix_env_from_name
35
+ from kinetix.environment.wrappers import (
36
+ AutoReplayWrapper,
37
+ AutoResetWrapper,
38
+ BatchEnvWrapper,
39
+ DenseRewardWrapper,
40
+ LogWrapper,
41
+ UnderspecifiedToGymnaxWrapper,
42
+ )
43
+ from kinetix.models.actor_critic import ScannedRNN
44
+ from kinetix.util.saving import (
45
+ load_train_state_from_wandb_artifact_path,
46
+ save_model_to_wandb,
47
+ )
48
+
49
+
50
+ class Transition(NamedTuple):
51
+ done: jnp.ndarray
52
+ action: jnp.ndarray
53
+ value: jnp.ndarray
54
+ reward: jnp.ndarray
55
+ log_prob: jnp.ndarray
56
+ obs: Any
57
+ info: jnp.ndarray
58
+
59
+
60
+ def make_train(config, env_params, static_env_params):
61
+ config["num_updates"] = config["total_timesteps"] // config["num_steps"] // config["num_train_envs"]
62
+ config["minibatch_size"] = config["num_train_envs"] * config["num_steps"] // config["num_minibatches"]
63
+
64
+ env = make_kinetix_env_from_name(config["env_name"], static_env_params=static_env_params)
65
+
66
+ if config["train_level_mode"] == "list":
67
+ reset_func = make_reset_train_function_with_list_of_levels(
68
+ config, config["train_levels_list"], static_env_params, is_loading_train_levels=True
69
+ )
70
+ elif config["train_level_mode"] == "random":
71
+ reset_func = make_reset_train_function_with_mutations(
72
+ env.physics_engine, env_params, env.static_env_params, config
73
+ )
74
+ else:
75
+ raise ValueError(f"Unknown train_level_mode: {config['train_level_mode']}")
76
+
77
+ env = UnderspecifiedToGymnaxWrapper(AutoResetWrapper(env, reset_func))
78
+
79
+ eval_env = make_kinetix_env_from_name(config["env_name"], static_env_params=static_env_params)
80
+ eval_env = UnderspecifiedToGymnaxWrapper(AutoReplayWrapper(eval_env))
81
+
82
+ env = DenseRewardWrapper(env)
83
+ env = LogWrapper(env)
84
+ env = BatchEnvWrapper(env, num_envs=config["num_train_envs"])
85
+
86
+ eval_env_nonbatch = LogWrapper(DenseRewardWrapper(eval_env))
87
+
88
+ def linear_schedule(count):
89
+ frac = 1.0 - (count // (config["num_minibatches"] * config["update_epochs"])) / config["num_updates"]
90
+ return config["lr"] * frac
91
+
92
+ def linear_warmup_cosine_decay_schedule(count):
93
+ frac = (count // (config["num_minibatches"] * config["update_epochs"])) / config[
94
+ "num_updates"
95
+ ] # between 0 and 1
96
+ delta = config["peak_lr"] - config["initial_lr"]
97
+ frac_diff_max = 1.0 - config["warmup_frac"]
98
+ frac_cosine = (frac - config["warmup_frac"]) / frac_diff_max
99
+
100
+ return jax.lax.select(
101
+ frac < config["warmup_frac"],
102
+ config["initial_lr"] + delta * frac / config["warmup_frac"],
103
+ config["peak_lr"] * jnp.maximum(0.0, 0.5 * (1.0 + jnp.cos(jnp.pi * ((frac_cosine) % 1.0)))),
104
+ )
105
+
106
+ def train(rng):
107
+ # INIT NETWORK
108
+ network = make_network_from_config(env, env_params, config)
109
+ rng, _rng = jax.random.split(rng)
110
+ obsv, env_state = env.reset(_rng, env_params)
111
+ dones = jnp.zeros((config["num_train_envs"]), dtype=jnp.bool_)
112
+ rng, _rng = jax.random.split(rng)
113
+ init_hstate = ScannedRNN.initialize_carry(config["num_train_envs"])
114
+ init_x = jax.tree.map(lambda x: x[None, ...], (obsv, dones))
115
+ network_params = network.init(_rng, init_hstate, init_x)
116
+
117
+ param_count = sum(x.size for x in jax.tree_util.tree_leaves(network_params))
118
+ obs_size = sum(x.size for x in jax.tree_util.tree_leaves(obsv)) // config["num_train_envs"]
119
+
120
+ print("Number of parameters", param_count, "size of obs: ", obs_size)
121
+ if config["anneal_lr"]:
122
+ tx = optax.chain(
123
+ optax.clip_by_global_norm(config["max_grad_norm"]),
124
+ optax.adam(learning_rate=linear_schedule, eps=1e-5),
125
+ )
126
+ elif config["warmup_lr"]:
127
+ tx = optax.chain(
128
+ optax.clip_by_global_norm(config["max_grad_norm"]),
129
+ optax.adamw(learning_rate=linear_warmup_cosine_decay_schedule, eps=1e-5),
130
+ )
131
+ else:
132
+ tx = optax.chain(
133
+ optax.clip_by_global_norm(config["max_grad_norm"]),
134
+ optax.adam(config["lr"], eps=1e-5),
135
+ )
136
+ train_state = TrainState.create(
137
+ apply_fn=network.apply,
138
+ params=network_params,
139
+ tx=tx,
140
+ )
141
+ if config["load_from_checkpoint"] != None:
142
+ print("LOADING from", config["load_from_checkpoint"], "with only params =", config["load_only_params"])
143
+ train_state = load_train_state_from_wandb_artifact_path(
144
+ train_state, config["load_from_checkpoint"], load_only_params=config["load_only_params"]
145
+ )
146
+ # INIT ENV
147
+ rng, _rng = jax.random.split(rng)
148
+ obsv, env_state = env.reset(_rng, env_params)
149
+ init_hstate = ScannedRNN.initialize_carry(config["num_train_envs"])
150
+ render_static_env_params = env.static_env_params.replace(downscale=1)
151
+ pixel_renderer = jax.jit(make_render_pixels(env_params, render_static_env_params))
152
+ pixel_render_fn = lambda x: pixel_renderer(x) / 255.0
153
+ eval_levels = get_eval_levels(config["eval_levels"], env.static_env_params)
154
+
155
+ def _vmapped_eval_step(runner_state, rng):
156
+ def _single_eval_step(rng):
157
+ return general_eval(
158
+ rng,
159
+ eval_env_nonbatch,
160
+ env_params,
161
+ runner_state[0],
162
+ eval_levels,
163
+ env_params.max_timesteps,
164
+ config["num_eval_levels"],
165
+ keep_states=True,
166
+ return_trajectories=True,
167
+ )
168
+
169
+ (states, returns, done_idxs, episode_lengths, eval_infos), (eval_dones, eval_rewards) = jax.vmap(
170
+ _single_eval_step
171
+ )(jax.random.split(rng, config["eval_num_attempts"]))
172
+ eval_solves = (eval_infos["returned_episode_solved"] * eval_dones).sum(axis=1) / jnp.maximum(
173
+ 1, eval_dones.sum(axis=1)
174
+ )
175
+ states_to_plot = jax.tree.map(lambda x: x[0], states)
176
+ # obs = jax.vmap(jax.vmap(pixel_render_fn))(states_to_plot.env_state.env_state.env_state)
177
+
178
+ return (
179
+ states_to_plot,
180
+ done_idxs[0],
181
+ returns[0],
182
+ returns.mean(axis=0),
183
+ episode_lengths.mean(axis=0),
184
+ eval_solves.mean(axis=0),
185
+ )
186
+
187
+ # TRAIN LOOP
188
+ def _update_step(runner_state, unused):
189
+ # COLLECT TRAJECTORIES
190
+ def _env_step(runner_state, unused):
191
+ (
192
+ train_state,
193
+ env_state,
194
+ last_obs,
195
+ last_done,
196
+ hstate,
197
+ rng,
198
+ update_step,
199
+ ) = runner_state
200
+
201
+ # SELECT ACTION
202
+ rng, _rng = jax.random.split(rng)
203
+ ac_in = (jax.tree.map(lambda x: x[np.newaxis, :], last_obs), last_done[np.newaxis, :])
204
+ hstate, pi, value = network.apply(train_state.params, hstate, ac_in)
205
+ action = pi.sample(seed=_rng)
206
+ log_prob = pi.log_prob(action)
207
+ value, action, log_prob = (
208
+ value.squeeze(0),
209
+ action.squeeze(0),
210
+ log_prob.squeeze(0),
211
+ )
212
+
213
+ # STEP ENV
214
+ rng, _rng = jax.random.split(rng)
215
+ obsv, env_state, reward, done, info = env.step(_rng, env_state, action, env_params)
216
+ transition = Transition(last_done, action, value, reward, log_prob, last_obs, info)
217
+ runner_state = (
218
+ train_state,
219
+ env_state,
220
+ obsv,
221
+ done,
222
+ hstate,
223
+ rng,
224
+ update_step,
225
+ )
226
+ return runner_state, transition
227
+
228
+ initial_hstate = runner_state[-3]
229
+ runner_state, traj_batch = jax.lax.scan(_env_step, runner_state, None, config["num_steps"])
230
+
231
+ # CALCULATE ADVANTAGE
232
+ (
233
+ train_state,
234
+ env_state,
235
+ last_obs,
236
+ last_done,
237
+ hstate,
238
+ rng,
239
+ update_step,
240
+ ) = runner_state
241
+ ac_in = (jax.tree.map(lambda x: x[np.newaxis, :], last_obs), last_done[np.newaxis, :])
242
+ _, _, last_val = network.apply(train_state.params, hstate, ac_in)
243
+ last_val = last_val.squeeze(0)
244
+
245
+ def _calculate_gae(traj_batch, last_val, last_done):
246
+ def _get_advantages(carry, transition):
247
+ gae, next_value, next_done = carry
248
+ done, value, reward = (
249
+ transition.done,
250
+ transition.value,
251
+ transition.reward,
252
+ )
253
+ delta = reward + config["gamma"] * next_value * (1 - next_done) - value
254
+ gae = delta + config["gamma"] * config["gae_lambda"] * (1 - next_done) * gae
255
+ return (gae, value, done), gae
256
+
257
+ _, advantages = jax.lax.scan(
258
+ _get_advantages,
259
+ (jnp.zeros_like(last_val), last_val, last_done),
260
+ traj_batch,
261
+ reverse=True,
262
+ unroll=16,
263
+ )
264
+ return advantages, advantages + traj_batch.value
265
+
266
+ advantages, targets = _calculate_gae(traj_batch, last_val, last_done)
267
+
268
+ # UPDATE NETWORK
269
+ def _update_epoch(update_state, unused):
270
+ def _update_minbatch(train_state, batch_info):
271
+ init_hstate, traj_batch, advantages, targets = batch_info
272
+
273
+ def _loss_fn(params, init_hstate, traj_batch, gae, targets):
274
+ # RERUN NETWORK
275
+ _, pi, value = network.apply(params, init_hstate[0], (traj_batch.obs, traj_batch.done))
276
+ log_prob = pi.log_prob(traj_batch.action)
277
+
278
+ # CALCULATE VALUE LOSS
279
+ value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip(
280
+ -config["clip_eps"], config["clip_eps"]
281
+ )
282
+ value_losses = jnp.square(value - targets)
283
+ value_losses_clipped = jnp.square(value_pred_clipped - targets)
284
+ value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
285
+
286
+ # CALCULATE ACTOR LOSS
287
+ ratio = jnp.exp(log_prob - traj_batch.log_prob)
288
+ gae = (gae - gae.mean()) / (gae.std() + 1e-8)
289
+ loss_actor1 = ratio * gae
290
+ loss_actor2 = (
291
+ jnp.clip(
292
+ ratio,
293
+ 1.0 - config["clip_eps"],
294
+ 1.0 + config["clip_eps"],
295
+ )
296
+ * gae
297
+ )
298
+ loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
299
+ loss_actor = loss_actor.mean()
300
+ entropy = pi.entropy().mean()
301
+
302
+ total_loss = loss_actor + config["vf_coef"] * value_loss - config["ent_coef"] * entropy
303
+ return total_loss, (value_loss, loss_actor, entropy)
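A quick numeric illustration of the clipped surrogate computed in _loss_fn above (toy values, not from the repository): with clip_eps = 0.2 and a positive advantage, a probability ratio that drifts beyond 1 ± 0.2 stops contributing extra objective, since the minimum of the clipped and unclipped terms is taken (the loss is the negative mean of this quantity).

    import jax.numpy as jnp

    clip_eps = 0.2
    ratio = jnp.array([0.7, 1.0, 1.3])   # new_prob / old_prob
    gae = jnp.array([1.0, 1.0, 1.0])     # normalised advantages (all positive here)

    unclipped = ratio * gae
    clipped = jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * gae
    per_sample_objective = jnp.minimum(unclipped, clipped)
    print(per_sample_objective)  # [0.7 1.  1.2] -> the 1.3 ratio is capped at 1.2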
304
+
305
+ grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
306
+ total_loss, grads = grad_fn(train_state.params, init_hstate, traj_batch, advantages, targets)
307
+ train_state = train_state.apply_gradients(grads=grads)
308
+ return train_state, total_loss
309
+
310
+ (
311
+ train_state,
312
+ init_hstate,
313
+ traj_batch,
314
+ advantages,
315
+ targets,
316
+ rng,
317
+ ) = update_state
318
+ rng, _rng = jax.random.split(rng)
319
+ permutation = jax.random.permutation(_rng, config["num_train_envs"])
320
+ batch = (init_hstate, traj_batch, advantages, targets)
321
+
322
+ shuffled_batch = jax.tree_util.tree_map(lambda x: jnp.take(x, permutation, axis=1), batch)
323
+
324
+ minibatches = jax.tree_util.tree_map(
325
+ lambda x: jnp.swapaxes(
326
+ jnp.reshape(
327
+ x,
328
+ [x.shape[0], config["num_minibatches"], -1] + list(x.shape[2:]),
329
+ ),
330
+ 1,
331
+ 0,
332
+ ),
333
+ shuffled_batch,
334
+ )
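The reshape/swapaxes above splits data laid out as (num_steps, num_train_envs, ...) into num_minibatches chunks that each keep the full time axis, as the recurrent network needs whole trajectories. A shape-only sketch with made-up sizes:

    import jax.numpy as jnp

    num_steps, num_train_envs, num_minibatches = 8, 16, 4
    x = jnp.zeros((num_steps, num_train_envs, 5))  # e.g. one leaf of the trajectory pytree

    minibatches = jnp.swapaxes(
        jnp.reshape(x, [x.shape[0], num_minibatches, -1] + list(x.shape[2:])),
        1,
        0,
    )
    print(minibatches.shape)  # (4, 8, 4, 5): minibatch, time, envs-per-minibatch, features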
335
+
336
+ train_state, total_loss = jax.lax.scan(_update_minbatch, train_state, minibatches)
337
+ update_state = (
338
+ train_state,
339
+ init_hstate,
340
+ traj_batch,
341
+ advantages,
342
+ targets,
343
+ rng,
344
+ )
345
+ return update_state, total_loss
346
+
347
+ init_hstate = initial_hstate[None, :]  # add a leading time dimension: (1, batch, hidden)
348
+ update_state = (
349
+ train_state,
350
+ init_hstate,
351
+ traj_batch,
352
+ advantages,
353
+ targets,
354
+ rng,
355
+ )
356
+ update_state, loss_info = jax.lax.scan(_update_epoch, update_state, None, config["update_epochs"])
357
+ train_state = update_state[0]
358
+ metric = jax.tree.map(
359
+ lambda x: (x * traj_batch.info["returned_episode"]).sum() / traj_batch.info["returned_episode"].sum(),
360
+ traj_batch.info,
361
+ )
362
+ rng = update_state[-1]
363
+
364
+ if config["use_wandb"]:
365
+ vid_frequency = get_video_frequency(config, update_step)
366
+ rng, _rng = jax.random.split(rng)
367
+ to_log_videos = _vmapped_eval_step(runner_state, _rng)
368
+ should_log_videos = update_step % vid_frequency == 0
369
+ first = jax.lax.cond(
370
+ should_log_videos,
371
+ lambda: jax.vmap(jax.vmap(pixel_render_fn))(to_log_videos[0].env_state.env_state.env_state),
372
+ lambda: (
373
+ jnp.zeros(
374
+ (
375
+ env_params.max_timesteps,
376
+ config["num_eval_levels"],
377
+ *PixelObservations(env_params, render_static_env_params)
378
+ .observation_space(env_params)
379
+ .shape,
380
+ )
381
+ )
382
+ ),
383
+ )
384
+ to_log_videos = (first, should_log_videos, *to_log_videos[1:])
385
+
386
+ def callback(metric, raw_info, loss_info, update_step, to_log_videos):
387
+ to_log = {}
388
+ to_log["timing/num_updates"] = update_step
389
+ to_log["timing/num_env_steps"] = update_step * config["num_steps"] * config["num_train_envs"]
390
+ (
391
+ obs_vid,
392
+ should_log_videos,
393
+ idx_vid,
394
+ eval_return_vid,
395
+ eval_return_mean,
396
+ eval_eplen_mean,
397
+ eval_solverate_mean,
398
+ ) = to_log_videos
399
+ to_log["eval/mean_eval_return"] = eval_return_mean.mean()
400
+ to_log["eval/mean_eval_eplen"] = eval_eplen_mean.mean()
401
+ for i, eval_name in enumerate(config["eval_levels"]):
402
+ return_on_video = eval_return_vid[i]
403
+ to_log[f"eval_video/return_{eval_name}"] = return_on_video
404
+ to_log[f"eval_video/len_{eval_name}"] = idx_vid[i]
405
+ to_log[f"eval_avg/return_{eval_name}"] = eval_return_mean[i]
406
+ to_log[f"eval_avg/solve_rate_{eval_name}"] = eval_solverate_mean[i]
407
+
408
+ if should_log_videos:
409
+ for i, eval_name in enumerate(config["eval_levels"]):
410
+ obs_to_use = obs_vid[: idx_vid[i], i]
411
+ obs_to_use = np.asarray(obs_to_use).transpose(0, 3, 2, 1)[:, :, ::-1, :]
412
+ to_log[f"media/eval_video_{eval_name}"] = wandb.Video((obs_to_use * 255).astype(np.uint8))
413
+
414
+ wandb.log(to_log)
415
+
416
+ jax.debug.callback(callback, metric, traj_batch.info, loss_info, update_step, to_log_videos)
417
+
418
+ runner_state = (
419
+ train_state,
420
+ env_state,
421
+ last_obs,
422
+ last_done,
423
+ hstate,
424
+ rng,
425
+ update_step + 1,
426
+ )
427
+ return runner_state, metric
428
+
429
+ rng, _rng = jax.random.split(rng)
430
+ runner_state = (
431
+ train_state,
432
+ env_state,
433
+ obsv,
434
+ jnp.zeros((config["num_train_envs"]), dtype=bool),
435
+ init_hstate,
436
+ _rng,
437
+ 0,
438
+ )
439
+ runner_state, metric = jax.lax.scan(_update_step, runner_state, None, config["num_updates"])
440
+ return {"runner_state": runner_state, "metric": metric}
441
+
442
+ return train
443
+
444
+
445
+ @hydra.main(version_base=None, config_path="../configs", config_name="ppo")
446
+ def main(config):
447
+ config = normalise_config(OmegaConf.to_container(config), "PPO")
448
+ env_params, static_env_params = generate_params_from_config(config)
449
+ config["env_params"] = to_state_dict(env_params)
450
+ config["static_env_params"] = to_state_dict(static_env_params)
451
+
452
+ if config["use_wandb"]:
453
+ run = init_wandb(config, "PPO")
454
+
455
+ rng = jax.random.PRNGKey(config["seed"])
456
+ rng, _rng = jax.random.split(rng)
457
+ train_jit = jax.jit(make_train(config, env_params, static_env_params))
458
+
459
+ out = train_jit(_rng)
460
+
461
+ if config["use_wandb"]:
462
+ if config["save_policy"]:
463
+ train_state = jax.tree.map(lambda x: x, out["runner_state"][0])
464
+ save_model_to_wandb(train_state, config["total_timesteps"], config)
465
+
466
+
467
+ if __name__ == "__main__":
468
+ main()
Kinetix/experiments/sfl.py ADDED
@@ -0,0 +1,1067 @@
1
+ """
2
+ Based on PureJaxRL Implementation of PPO
3
+ """
4
+
5
+ import os
6
+ import sys
7
+ import time
8
+ import typing
9
+ from functools import partial
10
+ from typing import NamedTuple
11
+
12
+ import chex
13
+ import hydra
14
+ import jax
15
+ import jax.experimental
16
+ import jax.numpy as jnp
17
+ import matplotlib.pyplot as plt
18
+ import numpy as np
19
+ import optax
20
+ from flax.training.train_state import TrainState
21
+ from kinetix.environment.ued.ued import make_vmapped_filtered_level_sampler
22
+ from kinetix.environment.ued.ued import (
23
+ make_reset_train_function_with_list_of_levels,
24
+ make_reset_train_function_with_mutations,
25
+ )
26
+ from kinetix.util.config import (
27
+ generate_ued_params_from_config,
28
+ init_wandb,
29
+ normalise_config,
30
+ generate_params_from_config,
31
+ get_eval_level_groups,
32
+ )
33
+ from jaxued.environments.underspecified_env import EnvParams, EnvState, Observation, UnderspecifiedEnv
34
+ from omegaconf import OmegaConf
35
+ from PIL import Image
36
+ from flax.serialization import to_state_dict
37
+
38
+ import wandb
39
+ from kinetix.environment.env import make_kinetix_env_from_name
40
+ from kinetix.environment.wrappers import (
41
+ AutoReplayWrapper,
42
+ DenseRewardWrapper,
43
+ LogWrapper,
44
+ UnderspecifiedToGymnaxWrapper,
45
+ )
46
+ from kinetix.models import make_network_from_config
47
+ from kinetix.models.actor_critic import ScannedRNN
48
+ from kinetix.render.renderer_pixels import make_render_pixels
49
+ from kinetix.util.learning import general_eval, get_eval_levels
50
+ from kinetix.util.saving import (
51
+ load_train_state_from_wandb_artifact_path,
52
+ save_model_to_wandb,
53
+ )
54
+
55
+ sys.path.append("ued")
56
+ from flax.traverse_util import flatten_dict, unflatten_dict
57
+ from safetensors.flax import load_file, save_file
58
+
59
+
60
+ def save_params(params: typing.Dict, filename: typing.Union[str, os.PathLike]) -> None:
61
+ flattened_dict = flatten_dict(params, sep=",")
62
+ save_file(flattened_dict, filename)
63
+
64
+
65
+ def load_params(filename: typing.Union[str, os.PathLike]) -> typing.Dict:
66
+ flattened_dict = load_file(filename)
67
+ return unflatten_dict(flattened_dict, sep=",")
68
+
69
+
70
+ class Transition(NamedTuple):
71
+ global_done: jnp.ndarray
72
+ done: jnp.ndarray
73
+ action: jnp.ndarray
74
+ value: jnp.ndarray
75
+ reward: jnp.ndarray
76
+ log_prob: jnp.ndarray
77
+ obs: jnp.ndarray
78
+ info: jnp.ndarray
79
+
80
+
81
+ class RolloutBatch(NamedTuple):
82
+ obs: jnp.ndarray
83
+ actions: jnp.ndarray
84
+ rewards: jnp.ndarray
85
+ dones: jnp.ndarray
86
+ log_probs: jnp.ndarray
87
+ values: jnp.ndarray
88
+ targets: jnp.ndarray
89
+ advantages: jnp.ndarray
90
+ # carry: jnp.ndarray
91
+ mask: jnp.ndarray
92
+
93
+
94
+ def evaluate_rnn(
95
+ rng: chex.PRNGKey,
96
+ env: UnderspecifiedEnv,
97
+ env_params: EnvParams,
98
+ train_state: TrainState,
99
+ init_hstate: chex.ArrayTree,
100
+ init_obs: Observation,
101
+ init_env_state: EnvState,
102
+ max_episode_length: int,
103
+ keep_states=True,
104
+ ) -> tuple[chex.Array, chex.Array, chex.Array, chex.Array]:
105
+ """This runs the RNN on the environment, given an initial state and observation, and returns (states, rewards, episode_lengths)
106
+
107
+ Args:
108
+ rng (chex.PRNGKey):
109
+ env (UnderspecifiedEnv):
110
+ env_params (EnvParams):
111
+ train_state (TrainState):
112
+ init_hstate (chex.ArrayTree): Shape (num_levels, )
113
+ init_obs (Observation): Shape (num_levels, )
114
+ init_env_state (EnvState): Shape (num_levels, )
115
+ max_episode_length (int):
116
+
117
+ Returns:
118
+ Tuple[chex.Array, chex.Array, chex.Array, chex.Array]: (states, rewards, episode lengths, infos), with shapes ((NUM_STEPS, NUM_LEVELS), (NUM_STEPS, NUM_LEVELS), (NUM_LEVELS,), (NUM_STEPS, NUM_LEVELS))
119
+ """
120
+ num_levels = jax.tree_util.tree_flatten(init_obs)[0][0].shape[0]
121
+
122
+ def step(carry, _):
123
+ rng, hstate, obs, state, done, mask, episode_length = carry
124
+ rng, rng_action, rng_step = jax.random.split(rng, 3)
125
+
126
+ x = jax.tree.map(lambda x: x[None, ...], (obs, done))
127
+ hstate, pi, _ = train_state.apply_fn(train_state.params, hstate, x)
128
+ action = pi.sample(seed=rng_action).squeeze(0)
129
+
130
+ obs, next_state, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0, None))(
131
+ jax.random.split(rng_step, num_levels), state, action, env_params
132
+ )
133
+
134
+ next_mask = mask & ~done
135
+ episode_length += mask
136
+
137
+ if keep_states:
138
+ return (rng, hstate, obs, next_state, done, next_mask, episode_length), (state, reward, info)
139
+ else:
140
+ return (rng, hstate, obs, next_state, done, next_mask, episode_length), (None, reward, info)
141
+
142
+ (_, _, _, _, _, _, episode_lengths), (states, rewards, infos) = jax.lax.scan(
143
+ step,
144
+ (
145
+ rng,
146
+ init_hstate,
147
+ init_obs,
148
+ init_env_state,
149
+ jnp.zeros(num_levels, dtype=bool),
150
+ jnp.ones(num_levels, dtype=bool),
151
+ jnp.zeros(num_levels, dtype=jnp.int32),
152
+ ),
153
+ None,
154
+ length=max_episode_length,
155
+ )
156
+
157
+ return states, rewards, episode_lengths, infos
158
+
159
+
160
+ @hydra.main(version_base=None, config_path="../configs", config_name="sfl")
161
+ def main(config):
162
+ time_start = time.time()
163
+ config = OmegaConf.to_container(config)
164
+ config = normalise_config(config, "SFL" if config["ued"]["sampled_envs_ratio"] > 0 else "SFL-DR")
165
+ env_params, static_env_params = generate_params_from_config(config)
166
+ config["env_params"] = to_state_dict(env_params)
167
+ config["static_env_params"] = to_state_dict(static_env_params)
168
+ run = init_wandb(config, "SFL")
169
+
170
+ rng = jax.random.PRNGKey(config["seed"])
171
+
172
+ config["num_envs_from_sampled"] = int(config["num_train_envs"] * config["sampled_envs_ratio"])
173
+ config["num_envs_to_generate"] = int(config["num_train_envs"] * (1 - config["sampled_envs_ratio"]))
174
+ assert (config["num_envs_from_sampled"] + config["num_envs_to_generate"]) == config["num_train_envs"]
175
+
176
+ def make_env(static_env_params):
177
+ env = make_kinetix_env_from_name(config["env_name"], static_env_params=static_env_params)
178
+ env = AutoReplayWrapper(env)
179
+ env = UnderspecifiedToGymnaxWrapper(env)
180
+ env = DenseRewardWrapper(env, dense_reward_scale=config["dense_reward_scale"])
181
+ env = LogWrapper(env)
182
+ return env
183
+
184
+ env = make_env(static_env_params)
185
+
186
+ if config["train_level_mode"] == "list":
187
+ sample_random_level = make_reset_train_function_with_list_of_levels(
188
+ config, config["train_levels"], static_env_params, make_pcg_state=False, is_loading_train_levels=True
189
+ )
190
+ elif config["train_level_mode"] == "random":
191
+ sample_random_level = make_reset_train_function_with_mutations(
192
+ env.physics_engine, env_params, static_env_params, config, make_pcg_state=False
193
+ )
194
+ else:
195
+ raise ValueError(f"Unknown train_level_mode: {config['train_level_mode']}")
196
+
197
+ sample_random_levels = make_vmapped_filtered_level_sampler(
198
+ sample_random_level, env_params, static_env_params, config, make_pcg_state=False, env=env
199
+ )
200
+ _, eval_static_env_params = generate_params_from_config(
201
+ config["eval_env_size_true"] | {"frame_skip": config["frame_skip"]}
202
+ )
203
+ eval_env = make_env(eval_static_env_params)
204
+ ued_params = generate_ued_params_from_config(config)
205
+
206
+ def make_render_fn(static_env_params):
207
+ render_fn_inner = make_render_pixels(env_params, static_env_params)
208
+ render_fn = lambda x: render_fn_inner(x).transpose(1, 0, 2)[::-1]
209
+ return render_fn
210
+
211
+ render_fn = make_render_fn(static_env_params)
212
+ render_fn_eval = make_render_fn(eval_static_env_params)
213
+
214
+ NUM_EVAL_DR_LEVELS = 200
215
+ key_to_sample_dr_eval_set = jax.random.PRNGKey(100)
216
+ DR_EVAL_LEVELS = sample_random_levels(key_to_sample_dr_eval_set, NUM_EVAL_DR_LEVELS)
217
+
218
+ print("Hello here num steps is ", config["num_steps"])
219
+ print("CONFIG is ", config)
220
+
221
+ config["total_timesteps"] = config["num_updates"] * config["num_steps"] * config["num_train_envs"]
222
+ config["minibatch_size"] = config["num_train_envs"] * config["num_steps"] // config["num_minibatches"]
223
+ config["clip_eps"] = config["clip_eps"]
224
+
225
+ config["env_name"] = config["env_name"]
226
+ network = make_network_from_config(env, env_params, config)
227
+
228
+ def linear_schedule(count):
229
+ count = count // (config["num_minibatches"] * config["update_epochs"])
230
+ frac = 1.0 - count / config["num_updates"]
231
+ return config["lr"] * frac
232
+
233
+ # INIT NETWORK
234
+ rng, _rng = jax.random.split(rng)
235
+ train_envs = 32 # Keep this small to avoid running out of memory; the initial sample size does not matter.
236
+ obs, _ = env.reset_to_level(rng, sample_random_level(rng), env_params)
237
+ obs = jax.tree.map(
238
+ lambda x: jnp.repeat(jnp.repeat(x[None, ...], train_envs, axis=0)[None, ...], 256, axis=0),
239
+ obs,
240
+ )
241
+ init_x = (obs, jnp.zeros((256, train_envs)))
242
+ init_hstate = ScannedRNN.initialize_carry(train_envs)
243
+ network_params = network.init(_rng, init_hstate, init_x)
244
+ if config["anneal_lr"]:
245
+ tx = optax.chain(
246
+ optax.clip_by_global_norm(config["max_grad_norm"]),
247
+ optax.adam(learning_rate=linear_schedule, eps=1e-5),
248
+ )
249
+ else:
250
+ tx = optax.chain(
251
+ optax.clip_by_global_norm(config["max_grad_norm"]),
252
+ optax.adam(config["lr"], eps=1e-5),
253
+ )
254
+ train_state = TrainState.create(
255
+ apply_fn=network.apply,
256
+ params=network_params,
257
+ tx=tx,
258
+ )
259
+ if config["load_from_checkpoint"] != None:
260
+ print("LOADING from", config["load_from_checkpoint"], "with only params =", config["load_only_params"])
261
+ train_state = load_train_state_from_wandb_artifact_path(
262
+ train_state,
263
+ config["load_from_checkpoint"],
264
+ load_only_params=config["load_only_params"],
265
+ legacy=config["load_legacy_checkpoint"],
266
+ )
267
+
268
+ rng, _rng = jax.random.split(rng)
269
+
270
+ # INIT ENV
271
+ rng, _rng, _rng2 = jax.random.split(rng, 3)
272
+ rng_reset = jax.random.split(_rng, config["num_train_envs"])
273
+
274
+ new_levels = sample_random_levels(_rng2, config["num_train_envs"])
275
+ obsv, env_state = jax.vmap(env.reset_to_level, in_axes=(0, 0, None))(rng_reset, new_levels, env_params)
276
+
277
+ start_state = env_state
278
+ init_hstate = ScannedRNN.initialize_carry(config["num_train_envs"])
279
+
280
+ @jax.jit
281
+ def log_buffer_learnability(rng, train_state, instances):
282
+ BATCH_SIZE = config["num_to_save"]
283
+ BATCH_ACTORS = BATCH_SIZE
284
+
285
+ def _batch_step(unused, rng):
286
+ def _env_step(runner_state, unused):
287
+ env_state, start_state, last_obs, last_done, hstate, rng = runner_state
288
+
289
+ # SELECT ACTION
290
+ rng, _rng = jax.random.split(rng)
291
+ obs_batch = last_obs
292
+ ac_in = (
293
+ jax.tree.map(lambda x: x[np.newaxis, :], obs_batch),
294
+ last_done[np.newaxis, :],
295
+ )
296
+ hstate, pi, value = network.apply(train_state.params, hstate, ac_in)
297
+ action = pi.sample(seed=_rng).squeeze()
298
+ log_prob = pi.log_prob(action)
299
+ env_act = action
300
+
301
+ # STEP ENV
302
+ rng, _rng = jax.random.split(rng)
303
+ rng_step = jax.random.split(_rng, config["num_to_save"])
304
+ obsv, env_state, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0, None))(
305
+ rng_step, env_state, env_act, env_params
306
+ )
307
+ done_batch = done
308
+
309
+ transition = Transition(
310
+ done,
311
+ last_done,
312
+ action.squeeze(),
313
+ value.squeeze(),
314
+ reward,
315
+ log_prob.squeeze(),
316
+ obs_batch,
317
+ info,
318
+ )
319
+ runner_state = (env_state, start_state, obsv, done_batch, hstate, rng)
320
+ return runner_state, transition
321
+
322
+ @partial(jax.vmap, in_axes=(None, 1, 1, 1))
323
+ @partial(jax.jit, static_argnums=(0,))
324
+ def _calc_outcomes_by_agent(max_steps: int, dones, returns, info):
325
+ idxs = jnp.arange(max_steps)
326
+
327
+ @partial(jax.vmap, in_axes=(0, 0))
328
+ def __ep_outcomes(start_idx, end_idx):
329
+ mask = (idxs > start_idx) & (idxs <= end_idx) & (end_idx != max_steps)
330
+ r = jnp.sum(returns * mask)
331
+ goal_r = info["GoalR"] # (returns > 0) * 1.0
332
+ success = jnp.sum(goal_r * mask)
333
+ l = end_idx - start_idx
334
+ return r, success, l
335
+
336
+ done_idxs = jnp.argwhere(dones, size=50, fill_value=max_steps).squeeze()
337
+ mask_done = jnp.where(done_idxs == max_steps, 0, 1)
338
+ ep_return, success, length = __ep_outcomes(
339
+ jnp.concatenate([jnp.array([-1]), done_idxs[:-1]]), done_idxs
340
+ )
341
+
342
+ return {
343
+ "ep_return": ep_return.mean(where=mask_done),
344
+ "num_episodes": mask_done.sum(),
345
+ "success_rate": success.mean(where=mask_done),
346
+ "ep_len": length.mean(where=mask_done),
347
+ }
348
+
349
+ # sample envs
350
+ rng, _rng, _rng2 = jax.random.split(rng, 3)
351
+ rng_reset = jax.random.split(_rng, config["num_to_save"])
352
+ rng_levels = jax.random.split(_rng2, config["num_to_save"])
353
+ # obsv, env_state = jax.vmap(sample_random_level, in_axes=(0,))(reset_rng)
354
+ # new_levels = jax.vmap(sample_random_level)(rng_levels)
355
+ obsv, env_state = jax.vmap(env.reset_to_level, in_axes=(0, 0, None))(rng_reset, instances, env_params)
356
+ # env_instances = new_levels
357
+ init_hstate = ScannedRNN.initialize_carry(
358
+ BATCH_ACTORS,
359
+ )
360
+
361
+ runner_state = (env_state, env_state, obsv, jnp.zeros((BATCH_ACTORS), dtype=bool), init_hstate, rng)
362
+ runner_state, traj_batch = jax.lax.scan(_env_step, runner_state, None, config["rollout_steps"])
363
+ done_by_env = traj_batch.done.reshape((-1, config["num_to_save"]))
364
+ reward_by_env = traj_batch.reward.reshape((-1, config["num_to_save"]))
365
+ # info_by_actor = jax.tree.map(lambda x: x.swapaxes(2, 1).reshape((-1, BATCH_ACTORS)), traj_batch.info)
366
+ o = _calc_outcomes_by_agent(config["rollout_steps"], traj_batch.done, traj_batch.reward, traj_batch.info)
367
+ success_by_env = o["success_rate"].reshape((1, config["num_to_save"]))
368
+ learnability_by_env = (success_by_env * (1 - success_by_env)).sum(axis=0)
369
+ return None, (learnability_by_env, success_by_env.sum(axis=0))
370
+
371
+ rngs = jax.random.split(rng, 1)
372
+ _, (learnability, success_by_env) = jax.lax.scan(_batch_step, None, rngs, 1)
373
+ return learnability[0], success_by_env[0]
374
+
375
+ num_eval_levels = len(config["eval_levels"])
376
+ all_eval_levels = get_eval_levels(config["eval_levels"], eval_env.static_env_params)
377
+
378
+ eval_group_indices = get_eval_level_groups(config["eval_levels"])
379
+ print("group indices", eval_group_indices)
380
+
381
+ @jax.jit
382
+ def get_learnability_set(rng, network_params):
383
+
384
+ BATCH_ACTORS = config["batch_size"]
385
+
386
+ def _batch_step(unused, rng):
387
+ def _env_step(runner_state, unused):
388
+ env_state, start_state, last_obs, last_done, hstate, rng = runner_state
389
+
390
+ # SELECT ACTION
391
+ rng, _rng = jax.random.split(rng)
392
+ obs_batch = last_obs
393
+ ac_in = (
394
+ jax.tree.map(lambda x: x[np.newaxis, :], obs_batch),
395
+ last_done[np.newaxis, :],
396
+ )
397
+ hstate, pi, value = network.apply(network_params, hstate, ac_in)
398
+ action = pi.sample(seed=_rng).squeeze()
399
+ log_prob = pi.log_prob(action)
400
+ env_act = action
401
+
402
+ # STEP ENV
403
+ rng, _rng = jax.random.split(rng)
404
+ rng_step = jax.random.split(_rng, config["batch_size"])
405
+ obsv, env_state, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0, None))(
406
+ rng_step, env_state, env_act, env_params
407
+ )
408
+ done_batch = done
409
+
410
+ transition = Transition(
411
+ done,
412
+ last_done,
413
+ action.squeeze(),
414
+ value.squeeze(),
415
+ reward,
416
+ log_prob.squeeze(),
417
+ obs_batch,
418
+ info,
419
+ )
420
+ runner_state = (env_state, start_state, obsv, done_batch, hstate, rng)
421
+ return runner_state, transition
422
+
423
+ @partial(jax.vmap, in_axes=(None, 1, 1, 1))
424
+ @partial(jax.jit, static_argnums=(0,))
425
+ def _calc_outcomes_by_agent(max_steps: int, dones, returns, info):
426
+ idxs = jnp.arange(max_steps)
427
+
428
+ @partial(jax.vmap, in_axes=(0, 0))
429
+ def __ep_outcomes(start_idx, end_idx):
430
+ mask = (idxs > start_idx) & (idxs <= end_idx) & (end_idx != max_steps)
431
+ r = jnp.sum(returns * mask)
432
+ goal_r = info["GoalR"] # (returns > 0) * 1.0
433
+ success = jnp.sum(goal_r * mask)
434
+ l = end_idx - start_idx
435
+ return r, success, l
436
+
437
+ done_idxs = jnp.argwhere(dones, size=50, fill_value=max_steps).squeeze()
438
+ mask_done = jnp.where(done_idxs == max_steps, 0, 1)
439
+ ep_return, success, length = __ep_outcomes(
440
+ jnp.concatenate([jnp.array([-1]), done_idxs[:-1]]), done_idxs
441
+ )
442
+
443
+ return {
444
+ "ep_return": ep_return.mean(where=mask_done),
445
+ "num_episodes": mask_done.sum(),
446
+ "success_rate": success.mean(where=mask_done),
447
+ "ep_len": length.mean(where=mask_done),
448
+ }
449
+
450
+ # sample envs
451
+ rng, _rng, _rng2 = jax.random.split(rng, 3)
452
+ rng_reset = jax.random.split(_rng, config["batch_size"])
453
+ new_levels = sample_random_levels(_rng2, config["batch_size"])
454
+ obsv, env_state = jax.vmap(env.reset_to_level, in_axes=(0, 0, None))(rng_reset, new_levels, env_params)
455
+ env_instances = new_levels
456
+ init_hstate = ScannedRNN.initialize_carry(
457
+ BATCH_ACTORS,
458
+ )
459
+
460
+ runner_state = (env_state, env_state, obsv, jnp.zeros((BATCH_ACTORS), dtype=bool), init_hstate, rng)
461
+ runner_state, traj_batch = jax.lax.scan(_env_step, runner_state, None, config["rollout_steps"])
462
+ done_by_env = traj_batch.done.reshape((-1, config["batch_size"]))
463
+ reward_by_env = traj_batch.reward.reshape((-1, config["batch_size"]))
464
+ # info_by_actor = jax.tree.map(lambda x: x.swapaxes(2, 1).reshape((-1, BATCH_ACTORS)), traj_batch.info)
465
+ o = _calc_outcomes_by_agent(config["rollout_steps"], traj_batch.done, traj_batch.reward, traj_batch.info)
466
+ success_by_env = o["success_rate"].reshape((1, config["batch_size"]))
467
+ learnability_by_env = (success_by_env * (1 - success_by_env)).sum(axis=0)
468
+ return None, (learnability_by_env, success_by_env.sum(axis=0), env_instances)
469
+
470
+ if config["sampled_envs_ratio"] == 0.0:
471
+ print("Not doing any rollouts because sampled_envs_ratio is 0.0")
472
+ # Here we have zero envs, so we can literally just sample random ones because there is no point.
473
+ top_instances = sample_random_levels(_rng, config["num_to_save"])
474
+ top_success = top_learn = learnability = success_rates = jnp.zeros(config["num_to_save"])
475
+ else:
476
+ rngs = jax.random.split(rng, config["num_batches"])
477
+ _, (learnability, success_rates, env_instances) = jax.lax.scan(
478
+ _batch_step, None, rngs, config["num_batches"]
479
+ )
480
+
481
+ flat_env_instances = jax.tree.map(lambda x: x.reshape((-1,) + x.shape[2:]), env_instances)
482
+ learnability = learnability.flatten() + success_rates.flatten() * 0.001
483
+ top_1000 = jnp.argsort(learnability)[-config["num_to_save"] :]
484
+
485
+ top_1000_instances = jax.tree.map(lambda x: x.at[top_1000].get(), flat_env_instances)
486
+ top_learn, top_instances = learnability.at[top_1000].get(), top_1000_instances
487
+ top_success = success_rates.at[top_1000].get()
488
+
489
+ if config["put_eval_levels_in_buffer"]:
490
+ top_instances = jax.tree.map(
491
+ lambda all, new: jnp.concatenate([all[:-num_eval_levels], new], axis=0),
492
+ top_instances,
493
+ all_eval_levels.env_state,
494
+ )
495
+
496
+ log = {
497
+ "learnability/learnability_sampled_mean": learnability.mean(),
498
+ "learnability/learnability_sampled_median": jnp.median(learnability),
499
+ "learnability/learnability_sampled_min": learnability.min(),
500
+ "learnability/learnability_sampled_max": learnability.max(),
501
+ "learnability/learnability_selected_mean": top_learn.mean(),
502
+ "learnability/learnability_selected_median": jnp.median(top_learn),
503
+ "learnability/learnability_selected_min": top_learn.min(),
504
+ "learnability/learnability_selected_max": top_learn.max(),
505
+ "learnability/solve_rate_sampled_mean": top_success.mean(),
506
+ "learnability/solve_rate_sampled_median": jnp.median(top_success),
507
+ "learnability/solve_rate_sampled_min": top_success.min(),
508
+ "learnability/solve_rate_sampled_max": top_success.max(),
509
+ "learnability/solve_rate_selected_mean": success_rates.mean(),
510
+ "learnability/solve_rate_selected_median": jnp.median(success_rates),
511
+ "learnability/solve_rate_selected_min": success_rates.min(),
512
+ "learnability/solve_rate_selected_max": success_rates.max(),
513
+ }
514
+
515
+ return top_learn, top_instances, log
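The selection criterion computed above is the learnability score p·(1 - p), which peaks for levels the agent solves roughly half the time (the script additionally adds 0.001 × success rate as a tie-break). A toy example, with made-up success rates and a buffer of the top 3:

    import jax.numpy as jnp

    success_rates = jnp.array([0.0, 0.1, 0.5, 0.9, 1.0])   # per-level solve rates
    learnability = success_rates * (1.0 - success_rates)   # [0.0, 0.09, 0.25, 0.09, 0.0]

    num_to_save = 3
    top_idxs = jnp.argsort(learnability)[-num_to_save:]    # indices of the most learnable levels
    print(top_idxs, learnability[top_idxs])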
516
+
517
+ def eval(rng: chex.PRNGKey, train_state: TrainState, keep_states=True):
518
+ """
519
+ This evaluates the current policy on the set of evaluation levels specified by config["eval_levels"].
520
+ It returns ((states, cum_rewards, done_idxs, episode_lengths, infos), (dones, rewards)); states has shape (num_steps, num_eval_levels, ...) and the per-level arrays have shape (num_eval_levels,)
521
+ """
522
+ num_levels = len(config["eval_levels"])
523
+ # eval_levels = get_eval_levels(config["eval_levels"], eval_env.static_env_params)
524
+ return general_eval(
525
+ rng,
526
+ eval_env,
527
+ env_params,
528
+ train_state,
529
+ all_eval_levels,
530
+ env_params.max_timesteps,
531
+ num_levels,
532
+ keep_states=keep_states,
533
+ return_trajectories=True,
534
+ )
535
+
536
+ def eval_on_dr_levels(rng: chex.PRNGKey, train_state: TrainState, keep_states=False):
537
+ return general_eval(
538
+ rng,
539
+ env,
540
+ env_params,
541
+ train_state,
542
+ DR_EVAL_LEVELS,
543
+ env_params.max_timesteps,
544
+ NUM_EVAL_DR_LEVELS,
545
+ keep_states=keep_states,
546
+ )
547
+
548
+ def eval_on_top_learnable_levels(rng: chex.PRNGKey, train_state: TrainState, levels, keep_states=True):
549
+ N = 5
550
+ return general_eval(
551
+ rng,
552
+ env,
553
+ env_params,
554
+ train_state,
555
+ jax.tree.map(lambda x: x[:N], levels),
556
+ env_params.max_timesteps,
557
+ N,
558
+ keep_states=keep_states,
559
+ )
560
+
561
+ # TRAIN LOOP
562
+ def train_step(runner_state_instances, unused):
563
+ # COLLECT TRAJECTORIES
564
+ runner_state, instances = runner_state_instances
565
+ num_env_instances = instances.polygon.position.shape[0]
566
+
567
+ def _env_step(runner_state, unused):
568
+ train_state, env_state, start_state, last_obs, last_done, hstate, update_steps, rng = runner_state
569
+
570
+ # SELECT ACTION
571
+ rng, _rng = jax.random.split(rng)
572
+ obs_batch = last_obs
573
+ ac_in = (
574
+ jax.tree.map(lambda x: x[np.newaxis, :], obs_batch),
575
+ last_done[np.newaxis, :],
576
+ )
577
+ hstate, pi, value = network.apply(train_state.params, hstate, ac_in)
578
+ action = pi.sample(seed=_rng).squeeze()
579
+ log_prob = pi.log_prob(action)
580
+ env_act = action
581
+
582
+ # STEP ENV
583
+ rng, _rng = jax.random.split(rng)
584
+ rng_step = jax.random.split(_rng, config["num_train_envs"])
585
+ obsv, env_state, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0, None))(
586
+ rng_step, env_state, env_act, env_params
587
+ )
588
+ done_batch = done
589
+ transition = Transition(
590
+ done,
591
+ last_done,
592
+ action.squeeze(),
593
+ value.squeeze(),
594
+ reward,
595
+ log_prob.squeeze(),
596
+ obs_batch,
597
+ info,
598
+ )
599
+ runner_state = (train_state, env_state, start_state, obsv, done_batch, hstate, update_steps, rng)
600
+ return runner_state, (transition)
601
+
602
+ initial_hstate = runner_state[-3]
603
+ runner_state, traj_batch = jax.lax.scan(_env_step, runner_state, None, config["num_steps"])
604
+
605
+ # CALCULATE ADVANTAGE
606
+ train_state, env_state, start_state, last_obs, last_done, hstate, update_steps, rng = runner_state
607
+ last_obs_batch = last_obs # batchify(last_obs, env.agents, config["num_train_envs"])
608
+ ac_in = (
609
+ jax.tree.map(lambda x: x[np.newaxis, :], last_obs_batch),
610
+ last_done[np.newaxis, :],
611
+ )
612
+ _, _, last_val = network.apply(train_state.params, hstate, ac_in)
613
+ last_val = last_val.squeeze()
614
+
615
+ def _calculate_gae(traj_batch, last_val):
616
+ def _get_advantages(gae_and_next_value, transition: Transition):
617
+ gae, next_value = gae_and_next_value
618
+ done, value, reward = (
619
+ transition.global_done,
620
+ transition.value,
621
+ transition.reward,
622
+ )
623
+ delta = reward + config["gamma"] * next_value * (1 - done) - value
624
+ gae = delta + config["gamma"] * config["gae_lambda"] * (1 - done) * gae
625
+ return (gae, value), gae
626
+
627
+ _, advantages = jax.lax.scan(
628
+ _get_advantages,
629
+ (jnp.zeros_like(last_val), last_val),
630
+ traj_batch,
631
+ reverse=True,
632
+ unroll=16,
633
+ )
634
+ return advantages, advantages + traj_batch.value
635
+
636
+ advantages, targets = _calculate_gae(traj_batch, last_val)
637
+
638
+ # UPDATE NETWORK
639
+ def _update_epoch(update_state, unused):
640
+ def _update_minbatch(train_state, batch_info):
641
+ init_hstate, traj_batch, advantages, targets = batch_info
642
+
643
+ def _loss_fn_masked(params, init_hstate, traj_batch, gae, targets):
644
+
645
+ # RERUN NETWORK
646
+ _, pi, value = network.apply(
647
+ params,
648
+ jax.tree.map(lambda x: x.transpose(), init_hstate),
649
+ (traj_batch.obs, traj_batch.done),
650
+ )
651
+ log_prob = pi.log_prob(traj_batch.action)
652
+
653
+ # CALCULATE VALUE LOSS
654
+ value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip(
655
+ -config["clip_eps"], config["clip_eps"]
656
+ )
657
+ value_losses = jnp.square(value - targets)
658
+ value_losses_clipped = jnp.square(value_pred_clipped - targets)
659
+ value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped)
660
+ critic_loss = config["vf_coef"] * value_loss.mean()
661
+
662
+ # CALCULATE ACTOR LOSS
663
+ logratio = log_prob - traj_batch.log_prob
664
+ ratio = jnp.exp(logratio)
665
+ # if env.do_sep_reward: gae = gae.sum(axis=-1)
666
+ gae = (gae - gae.mean()) / (gae.std() + 1e-8)
667
+ loss_actor1 = ratio * gae
668
+ loss_actor2 = (
669
+ jnp.clip(
670
+ ratio,
671
+ 1.0 - config["clip_eps"],
672
+ 1.0 + config["clip_eps"],
673
+ )
674
+ * gae
675
+ )
676
+ loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
677
+ loss_actor = loss_actor.mean()
678
+ entropy = pi.entropy().mean()
679
+
680
+ approx_kl = jax.lax.stop_gradient(((ratio - 1) - logratio).mean())
681
+ clipfrac = jax.lax.stop_gradient((jnp.abs(ratio - 1) > config["clip_eps"]).mean())
682
+
683
+ total_loss = loss_actor + critic_loss - config["ent_coef"] * entropy
684
+ return total_loss, (value_loss, loss_actor, entropy, ratio, approx_kl, clipfrac)
685
+
686
+ grad_fn = jax.value_and_grad(_loss_fn_masked, has_aux=True)
687
+ total_loss, grads = grad_fn(train_state.params, init_hstate, traj_batch, advantages, targets)
688
+ train_state = train_state.apply_gradients(grads=grads)
689
+ return train_state, total_loss
690
+
691
+ (
692
+ train_state,
693
+ init_hstate,
694
+ traj_batch,
695
+ advantages,
696
+ targets,
697
+ rng,
698
+ ) = update_state
699
+ rng, _rng = jax.random.split(rng)
700
+
701
+ init_hstate = jax.tree.map(lambda x: jnp.reshape(x, (256, config["num_train_envs"])), init_hstate)
702
+ batch = (
703
+ init_hstate,
704
+ traj_batch,
705
+ advantages.squeeze(),
706
+ targets.squeeze(),
707
+ )
708
+ permutation = jax.random.permutation(_rng, config["num_train_envs"])
709
+
710
+ shuffled_batch = jax.tree_util.tree_map(lambda x: jnp.take(x, permutation, axis=1), batch)
711
+
712
+ minibatches = jax.tree_util.tree_map(
713
+ lambda x: jnp.swapaxes(
714
+ jnp.reshape(
715
+ x,
716
+ [x.shape[0], config["num_minibatches"], -1] + list(x.shape[2:]),
717
+ ),
718
+ 1,
719
+ 0,
720
+ ),
721
+ shuffled_batch,
722
+ )
723
+
724
+ train_state, total_loss = jax.lax.scan(_update_minbatch, train_state, minibatches)
725
+ # total_loss = jax.tree.map(lambda x: x.mean(), total_loss)
726
+ update_state = (
727
+ train_state,
728
+ init_hstate,
729
+ traj_batch,
730
+ advantages,
731
+ targets,
732
+ rng,
733
+ )
734
+ return update_state, total_loss
735
+
736
+ # init_hstate = initial_hstate[None, :].squeeze().transpose()
737
+ init_hstate = jax.tree.map(lambda x: x[None, :].squeeze().transpose(), initial_hstate)
738
+ update_state = (
739
+ train_state,
740
+ init_hstate,
741
+ traj_batch,
742
+ advantages,
743
+ targets,
744
+ rng,
745
+ )
746
+ update_state, loss_info = jax.lax.scan(_update_epoch, update_state, None, config["update_epochs"])
747
+ train_state = update_state[0]
748
+ metric = traj_batch.info
749
+ metric = jax.tree.map(
750
+ lambda x: x.reshape((config["num_steps"], config["num_train_envs"])), # , env.num_agents
751
+ traj_batch.info,
752
+ )
753
+ rng = update_state[-1]
754
+
755
+ def callback(metric):
756
+ dones = metric["dones"]
757
+ wandb.log(
758
+ {
759
+ "episode_return": (metric["returned_episode_returns"] * dones).sum() / jnp.maximum(1, dones.sum()),
760
+ "episode_solved": (metric["returned_episode_solved"] * dones).sum() / jnp.maximum(1, dones.sum()),
761
+ "episode_length": (metric["returned_episode_lengths"] * dones).sum() / jnp.maximum(1, dones.sum()),
762
+ "timing/num_env_steps": int(
763
+ int(metric["update_steps"]) * int(config["num_train_envs"]) * int(config["num_steps"])
764
+ ),
765
+ "timing/num_updates": metric["update_steps"],
766
+ **metric["loss_info"],
767
+ }
768
+ )
769
+
770
+ loss_info = jax.tree.map(lambda x: x.mean(), loss_info)
771
+ metric["loss_info"] = {
772
+ "loss/total_loss": loss_info[0],
773
+ "loss/value_loss": loss_info[1][0],
774
+ "loss/policy_loss": loss_info[1][1],
775
+ "loss/entropy_loss": loss_info[1][2],
776
+ }
777
+ metric["dones"] = traj_batch.done
778
+ metric["update_steps"] = update_steps
779
+ jax.experimental.io_callback(callback, None, metric)
780
+
781
+ # SAMPLE NEW ENVS
782
+ rng, _rng, _rng2 = jax.random.split(rng, 3)
783
+ rng_reset = jax.random.split(_rng, config["num_envs_to_generate"])
784
+
785
+ new_levels = sample_random_levels(_rng2, config["num_envs_to_generate"])
786
+ obsv_gen, env_state_gen = jax.vmap(env.reset_to_level, in_axes=(0, 0, None))(rng_reset, new_levels, env_params)
787
+
788
+ rng, _rng, _rng2 = jax.random.split(rng, 3)
789
+ sampled_env_instances_idxs = jax.random.randint(_rng, (config["num_envs_from_sampled"],), 0, num_env_instances)
790
+ sampled_env_instances = jax.tree.map(lambda x: x.at[sampled_env_instances_idxs].get(), instances)
791
+ myrng = jax.random.split(_rng2, config["num_envs_from_sampled"])
792
+ obsv_sampled, env_state_sampled = jax.vmap(env.reset_to_level, in_axes=(0, 0))(myrng, sampled_env_instances)
793
+
794
+ obsv = jax.tree.map(lambda x, y: jnp.concatenate([x, y], axis=0), obsv_gen, obsv_sampled)
795
+ env_state = jax.tree.map(lambda x, y: jnp.concatenate([x, y], axis=0), env_state_gen, env_state_sampled)
796
+
797
+ start_state = env_state
798
+ hstate = ScannedRNN.initialize_carry(config["num_train_envs"])
799
+
800
+ update_steps = update_steps + 1
801
+ runner_state = (
802
+ train_state,
803
+ env_state,
804
+ start_state,
805
+ obsv,
806
+ jnp.zeros((config["num_train_envs"]), dtype=bool),
807
+ hstate,
808
+ update_steps,
809
+ rng,
810
+ )
811
+ return (runner_state, instances), metric
812
+
813
+ def log_buffer(learnability, levels, epoch):
814
+ num_samples = levels.polygon.position.shape[0]
815
+ states = levels
816
+ rows = 2
817
+ fig, axes = plt.subplots(rows, int(num_samples / rows), figsize=(20, 10))
818
+ axes = axes.flatten()
819
+ all_imgs = jax.vmap(render_fn)(states)
820
+ for i, ax in enumerate(axes):
821
+ # ax.imshow(train_state.plr_buffer.get_sample(i))
822
+ score = learnability[i]
823
+ ax.imshow(all_imgs[i] / 255.0)
824
+ ax.set_xticks([])
825
+ ax.set_yticks([])
826
+ ax.set_title(f"learnability: {score:.3f}")
827
+ ax.set_aspect("equal", "box")
828
+
829
+ plt.tight_layout()
830
+ fig.canvas.draw()
831
+ im = Image.frombytes("RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
832
+ plt.close()
833
+ return {"maps": wandb.Image(im)}
834
+
835
+ @jax.jit
836
+ def train_and_eval_step(runner_state, eval_rng):
837
+
838
+ learnability_rng, eval_singleton_rng, eval_sampled_rng, _rng = jax.random.split(eval_rng, 4)
839
+ # TRAIN
840
+ learnability_scores, instances, test_metrics = get_learnability_set(learnability_rng, runner_state[0].params)
841
+
842
+ if config["log_learnability_before_after"]:
843
+ learn_scores_before, success_score_before = log_buffer_learnability(
844
+ learnability_rng, runner_state[0], instances
845
+ )
846
+
847
+ print("instance size", sum(x.size for x in jax.tree_util.tree_leaves(instances)))
848
+
849
+ runner_state_instances = (runner_state, instances)
850
+ runner_state_instances, metrics = jax.lax.scan(train_step, runner_state_instances, None, config["eval_freq"])
851
+
852
+ if config["log_learnability_before_after"]:
853
+ learn_scores_after, success_score_after = log_buffer_learnability(
854
+ learnability_rng, runner_state_instances[0][0], instances
855
+ )
856
+
857
+ # EVAL
858
+ rng, rng_eval = jax.random.split(eval_singleton_rng)
859
+ (states, cum_rewards, _, episode_lengths, eval_infos), (eval_dones, eval_rewards) = jax.vmap(eval, (0, None))(
860
+ jax.random.split(rng_eval, config["eval_num_attempts"]), runner_state_instances[0][0]
861
+ )
862
+ all_eval_eplens = episode_lengths
863
+
864
+ # Collect Metrics
865
+ eval_returns = cum_rewards.mean(axis=0) # (num_eval_levels,)
866
+ eval_solves = (eval_infos["returned_episode_solved"] * eval_dones).sum(axis=1) / jnp.maximum(
867
+ 1, eval_dones.sum(axis=1)
868
+ )
869
+ eval_solves = eval_solves.mean(axis=0)
870
+ # just grab the first run
871
+ states, episode_lengths = jax.tree_util.tree_map(
872
+ lambda x: x[0], (states, episode_lengths)
873
+ ) # (num_steps, num_eval_levels, ...), (num_eval_levels,)
874
+ # And one attempt
875
+ states = jax.tree_util.tree_map(lambda x: x[:, :], states)
876
+ episode_lengths = episode_lengths[:]
877
+ images = jax.vmap(jax.vmap(render_fn_eval))(
878
+ states.env_state.env_state.env_state
879
+ ) # (num_steps, num_eval_levels, ...)
880
+ frames = images.transpose(
881
+ 0, 1, 4, 2, 3
882
+ ) # WandB expects color channel before image dimensions when dealing with animations for some reason
883
+
884
+ test_metrics["update_count"] = runner_state[-2]
885
+ test_metrics["eval_returns"] = eval_returns
886
+ test_metrics["eval_ep_lengths"] = episode_lengths
887
+ test_metrics["eval_animation"] = (frames, episode_lengths)
888
+
889
+ # Eval on sampled
890
+ dr_states, dr_cum_rewards, _, dr_episode_lengths, dr_infos = jax.vmap(eval_on_dr_levels, (0, None))(
891
+ jax.random.split(rng_eval, config["eval_num_attempts"]), runner_state_instances[0][0]
892
+ )
893
+
894
+ eval_dr_returns = dr_cum_rewards.mean(axis=0).mean()
895
+ eval_dr_eplen = dr_episode_lengths.mean(axis=0).mean()
896
+
897
+ test_metrics["eval/mean_eval_return_sampled"] = eval_dr_returns
898
+ my_eval_dones = dr_infos["returned_episode"]
899
+ eval_dr_solves = (dr_infos["returned_episode_solved"] * my_eval_dones).sum(axis=1) / jnp.maximum(
900
+ 1, my_eval_dones.sum(axis=1)
901
+ )
902
+
903
+ test_metrics["eval/mean_eval_solve_rate_sampled"] = eval_dr_solves
904
+ test_metrics["eval/mean_eval_eplen_sampled"] = eval_dr_eplen
905
+
906
+ # Collect Metrics
907
+ eval_returns = cum_rewards.mean(axis=0) # (num_eval_levels,)
908
+
909
+ log_dict = {}
910
+
911
+ log_dict["to_remove"] = {
912
+ "eval_return": eval_returns,
913
+ "eval_solve_rate": eval_solves,
914
+ "eval_eplen": all_eval_eplens,
915
+ }
916
+
917
+ for i, name in enumerate(config["eval_levels"]):
918
+ log_dict[f"eval_avg_return/{name}"] = eval_returns[i]
919
+ log_dict[f"eval_avg_solve_rate/{name}"] = eval_solves[i]
920
+
921
+ log_dict.update({"eval/mean_eval_return": eval_returns.mean()})
922
+ log_dict.update({"eval/mean_eval_solve_rate": eval_solves.mean()})
923
+ log_dict.update({"eval/mean_eval_eplen": all_eval_eplens.mean()})
924
+
925
+ test_metrics.update(log_dict)
926
+
927
+ runner_state, _ = runner_state_instances
928
+ test_metrics["update_count"] = runner_state[-2]
929
+
930
+ top_instances = jax.tree.map(lambda x: x.at[-5:].get(), instances)
931
+
932
+ # Eval on top learnable levels
933
+ tl_states, tl_cum_rewards, _, tl_episode_lengths, tl_infos = jax.vmap(
934
+ eval_on_top_learnable_levels, (0, None, None)
935
+ )(jax.random.split(rng_eval, config["eval_num_attempts"]), runner_state_instances[0][0], top_instances)
936
+
937
+ # just grab the first run
938
+ states, episode_lengths = jax.tree_util.tree_map(
939
+ lambda x: x[0], (tl_states, tl_episode_lengths)
940
+ ) # (num_steps, num_eval_levels, ...), (num_eval_levels,)
941
+ # And one attempt
942
+ states = jax.tree_util.tree_map(lambda x: x[:, :], states)
943
+ episode_lengths = episode_lengths[:]
944
+ images = jax.vmap(jax.vmap(render_fn))(
945
+ states.env_state.env_state.env_state
946
+ ) # (num_steps, num_eval_levels, ...)
947
+ frames = images.transpose(
948
+ 0, 1, 4, 2, 3
949
+ ) # WandB expects color channel before image dimensions when dealing with animations for some reason
950
+
951
+ test_metrics["top_learnable_animation"] = (frames, episode_lengths, tl_cum_rewards)
952
+
953
+ if config["log_learnability_before_after"]:
954
+
955
+ def single(x, name):
956
+ return {
957
+ f"{name}_mean": x.mean(),
958
+ f"{name}_std": x.std(),
959
+ f"{name}_min": x.min(),
960
+ f"{name}_max": x.max(),
961
+ f"{name}_median": jnp.median(x),
962
+ }
963
+
964
+ test_metrics["learnability_log_v2/"] = {
965
+ **single(learn_scores_before, "learnability_before"),
966
+ **single(learn_scores_after, "learnability_after"),
967
+ **single(success_score_before, "success_score_before"),
968
+ **single(success_score_after, "success_score_after"),
969
+ }
970
+
971
+ return runner_state, (learnability_scores.at[-20:].get(), top_instances), test_metrics
972
+
973
+ rng, _rng = jax.random.split(rng)
974
+ runner_state = (
975
+ train_state,
976
+ env_state,
977
+ start_state,
978
+ obsv,
979
+ jnp.zeros((config["num_train_envs"]), dtype=bool),
980
+ init_hstate,
981
+ 0,
982
+ _rng,
983
+ )
984
+
985
+ def log_eval(stats):
986
+ log_dict = {}
987
+
988
+ to_remove = stats["to_remove"]
989
+ del stats["to_remove"]
990
+
991
+ def _aggregate_per_size(values, name):
992
+ to_return = {}
993
+ for group_name, indices in eval_group_indices.items():
994
+ to_return[f"{name}_{group_name}"] = values[indices].mean()
995
+ return to_return
996
+
997
+ env_steps = stats["update_count"] * config["num_train_envs"] * config["num_steps"]
998
+ env_steps_delta = config["eval_freq"] * config["num_train_envs"] * config["num_steps"]
999
+ time_now = time.time()
1000
+ log_dict = {
1001
+ "timing/num_updates": stats["update_count"],
1002
+ "timing/num_env_steps": env_steps,
1003
+ "timing/sps": env_steps_delta / stats["time_delta"],
1004
+ "timing/sps_agg": env_steps / (time_now - time_start),
1005
+ }
1006
+ log_dict.update(_aggregate_per_size(to_remove["eval_return"], "eval_aggregate/return"))
1007
+ log_dict.update(_aggregate_per_size(to_remove["eval_solve_rate"], "eval_aggregate/solve_rate"))
1008
+
1009
+ for i in range((len(config["eval_levels"]))):
1010
+ frames, episode_length = stats["eval_animation"][0][:, i], stats["eval_animation"][1][i]
1011
+ frames = np.array(frames[:episode_length])
1012
+ log_dict.update(
1013
+ {
1014
+ f"media/eval_video_{config['eval_levels'][i]}": wandb.Video(
1015
+ frames.astype(np.uint8), fps=15, caption=f"(len {episode_length})"
1016
+ )
1017
+ }
1018
+ )
1019
+
1020
+ for j in range(5):
1021
+ frames, episode_length, cum_rewards = (
1022
+ stats["top_learnable_animation"][0][:, j],
1023
+ stats["top_learnable_animation"][1][j],
1024
+ stats["top_learnable_animation"][2][:, j],
1025
+ ) # num attempts
1026
+ rr = "|".join([f"{r:<.2f}" for r in cum_rewards])
1027
+ frames = np.array(frames[:episode_length])
1028
+ log_dict.update(
1029
+ {
1030
+ f"media/tl_animation_{j}": wandb.Video(
1031
+ frames.astype(np.uint8), fps=15, caption=f"(len {episode_length})\n{rr}"
1032
+ )
1033
+ }
1034
+ )
1035
+
1036
+ stats.update(log_dict)
1037
+ wandb.log(stats, step=stats["update_count"])
1038
+
1039
+ checkpoint_steps = config["checkpoint_save_freq"]
1040
+ assert config["num_updates"] % config["eval_freq"] == 0, "num_updates must be divisible by eval_freq"
1041
+
1042
+ for eval_step in range(int(config["num_updates"] // config["eval_freq"])):
1043
+ start_time = time.time()
1044
+ rng, eval_rng = jax.random.split(rng)
1045
+ runner_state, instances, metrics = train_and_eval_step(runner_state, eval_rng)
1046
+ curr_time = time.time()
1047
+ metrics.update(log_buffer(*instances, metrics["update_count"]))
1048
+ metrics["time_delta"] = curr_time - start_time
1049
+ metrics["steps_per_section"] = (config["eval_freq"] * config["num_steps"] * config["num_train_envs"]) / metrics[
1050
+ "time_delta"
1051
+ ]
1052
+ log_eval(metrics)
1053
+ if ((eval_step + 1) * config["eval_freq"]) % checkpoint_steps == 0:
1054
+ if config["save_path"] is not None:
1055
+ steps = int(metrics["update_count"]) * int(config["num_train_envs"]) * int(config["num_steps"])
1056
+ # save_params_to_wandb(runner_state[0].params, steps, config)
1057
+ save_model_to_wandb(runner_state[0], steps, config)
1058
+
1059
+ if config["save_path"] is not None:
1060
+ # save_params_to_wandb(runner_state[0].params, config["total_timesteps"], config)
1061
+ save_model_to_wandb(runner_state[0], config["total_timesteps"], config)
1062
+
1063
+
1064
+ if __name__ == "__main__":
1065
+ # with jax.disable_jit():
1066
+ # main()
1067
+ main()
Kinetix/images/bb.gif ADDED
Kinetix/images/cartpole.gif ADDED