Spaces:
Runtime error
Runtime error
Upload 190 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +8 -0
- Kinetix/.gitignore +194 -0
- Kinetix/.pre-commit-config.yaml +7 -0
- Kinetix/LICENSE +19 -0
- Kinetix/README.md +217 -0
- Kinetix/configs/editor.yaml +22 -0
- Kinetix/configs/env/entity.yaml +3 -0
- Kinetix/configs/env/symbolic.yaml +3 -0
- Kinetix/configs/env_size/custom.yaml +3 -0
- Kinetix/configs/env_size/l.yaml +8 -0
- Kinetix/configs/env_size/m.yaml +8 -0
- Kinetix/configs/env_size/s.yaml +8 -0
- Kinetix/configs/eval/eval_all.yaml +82 -0
- Kinetix/configs/eval/eval_auto.yaml +4 -0
- Kinetix/configs/eval/eval_general.yaml +7 -0
- Kinetix/configs/eval/l.yaml +46 -0
- Kinetix/configs/eval/m.yaml +30 -0
- Kinetix/configs/eval/mujoco.yaml +13 -0
- Kinetix/configs/eval/s.yaml +16 -0
- Kinetix/configs/eval_env_size/l.yaml +7 -0
- Kinetix/configs/eval_env_size/m.yaml +7 -0
- Kinetix/configs/eval_env_size/s.yaml +7 -0
- Kinetix/configs/learning/ppo-base.yaml +20 -0
- Kinetix/configs/learning/ppo-rnn.yaml +2 -0
- Kinetix/configs/learning/ppo-sfl.yaml +1 -0
- Kinetix/configs/learning/ppo-ued.yaml +2 -0
- Kinetix/configs/misc/misc.yaml +16 -0
- Kinetix/configs/model/model-base.yaml +4 -0
- Kinetix/configs/model/model-transformer.yaml +6 -0
- Kinetix/configs/plr.yaml +17 -0
- Kinetix/configs/ppo.yaml +20 -0
- Kinetix/configs/sfl.yaml +21 -0
- Kinetix/configs/train_levels/l.yaml +44 -0
- Kinetix/configs/train_levels/m.yaml +28 -0
- Kinetix/configs/train_levels/mujoco.yaml +11 -0
- Kinetix/configs/train_levels/random.yaml +2 -0
- Kinetix/configs/train_levels/s.yaml +14 -0
- Kinetix/configs/train_levels/train_all.yaml +80 -0
- Kinetix/configs/ued/accel.yaml +16 -0
- Kinetix/configs/ued/plr.yaml +17 -0
- Kinetix/configs/ued/sfl.yaml +9 -0
- Kinetix/docs/README.md +83 -0
- Kinetix/docs/configs.md +179 -0
- Kinetix/examples/example_premade_level_replay.py +46 -0
- Kinetix/examples/example_random_level_replay.py +51 -0
- Kinetix/experiments/plr.py +1143 -0
- Kinetix/experiments/ppo.py +468 -0
- Kinetix/experiments/sfl.py +1067 -0
- Kinetix/images/bb.gif +0 -0
- Kinetix/images/cartpole.gif +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
Kinetix/images/general_2.gif filter=lfs diff=lfs merge=lfs -text
|
37 |
+
Kinetix/images/kinetix_logo.gif filter=lfs diff=lfs merge=lfs -text
|
38 |
+
Kinetix/images/random_1.gif filter=lfs diff=lfs merge=lfs -text
|
39 |
+
Kinetix/images/random_3.gif filter=lfs diff=lfs merge=lfs -text
|
40 |
+
Kinetix/images/random_4.gif filter=lfs diff=lfs merge=lfs -text
|
41 |
+
Kinetix/images/random_5.gif filter=lfs diff=lfs merge=lfs -text
|
42 |
+
Kinetix/images/random_6.gif filter=lfs diff=lfs merge=lfs -text
|
43 |
+
Kinetix/images/random_7.gif filter=lfs diff=lfs merge=lfs -text
|
Kinetix/.gitignore
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
tmp/
|
2 |
+
wandb/
|
3 |
+
runs/
|
4 |
+
|
5 |
+
play_data
|
6 |
+
checkpoints
|
7 |
+
|
8 |
+
# Byte-compiled / optimized / DLL files
|
9 |
+
__pycache__/
|
10 |
+
*.py[cod]
|
11 |
+
*$py.class
|
12 |
+
|
13 |
+
# C extensions
|
14 |
+
*.so
|
15 |
+
|
16 |
+
# Distribution / packaging
|
17 |
+
.Python
|
18 |
+
build/
|
19 |
+
develop-eggs/
|
20 |
+
dist/
|
21 |
+
downloads/
|
22 |
+
eggs/
|
23 |
+
.eggs/
|
24 |
+
lib/
|
25 |
+
lib64/
|
26 |
+
parts/
|
27 |
+
sdist/
|
28 |
+
var/
|
29 |
+
wheels/
|
30 |
+
share/python-wheels/
|
31 |
+
*.egg-info/
|
32 |
+
.installed.cfg
|
33 |
+
*.egg
|
34 |
+
MANIFEST
|
35 |
+
|
36 |
+
# PyInstaller
|
37 |
+
# Usually these files are written by a python script from a template
|
38 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
39 |
+
*.manifest
|
40 |
+
*.spec
|
41 |
+
|
42 |
+
# Installer logs
|
43 |
+
pip-log.txt
|
44 |
+
pip-delete-this-directory.txt
|
45 |
+
|
46 |
+
# Unit test / coverage reports
|
47 |
+
htmlcov/
|
48 |
+
.tox/
|
49 |
+
.nox/
|
50 |
+
.coverage
|
51 |
+
.coverage.*
|
52 |
+
.cache*
|
53 |
+
.cache_*
|
54 |
+
nosetests.xml
|
55 |
+
coverage.xml
|
56 |
+
*.cover
|
57 |
+
*.py,cover
|
58 |
+
.hypothesis/
|
59 |
+
.pytest_cache/
|
60 |
+
cover/
|
61 |
+
|
62 |
+
# Translations
|
63 |
+
*.mo
|
64 |
+
*.pot
|
65 |
+
|
66 |
+
# Django stuff:
|
67 |
+
*.log
|
68 |
+
local_settings.py
|
69 |
+
db.sqlite3
|
70 |
+
db.sqlite3-journal
|
71 |
+
|
72 |
+
# Flask stuff:
|
73 |
+
instance/
|
74 |
+
.webassets-cache
|
75 |
+
|
76 |
+
# Scrapy stuff:
|
77 |
+
.scrapy
|
78 |
+
|
79 |
+
# Sphinx documentation
|
80 |
+
docs/_build/
|
81 |
+
|
82 |
+
# PyBuilder
|
83 |
+
.pybuilder/
|
84 |
+
target/
|
85 |
+
|
86 |
+
# Jupyter Notebook
|
87 |
+
.ipynb_checkpoints
|
88 |
+
|
89 |
+
# IPython
|
90 |
+
profile_default/
|
91 |
+
ipython_config.py
|
92 |
+
|
93 |
+
# pyenv
|
94 |
+
# For a library or package, you might want to ignore these files since the code is
|
95 |
+
# intended to run in multiple environments; otherwise, check them in:
|
96 |
+
# .python-version
|
97 |
+
|
98 |
+
# pipenv
|
99 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
100 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
101 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
102 |
+
# install all needed dependencies.
|
103 |
+
#Pipfile.lock
|
104 |
+
|
105 |
+
# poetry
|
106 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
107 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
108 |
+
# commonly ignored for libraries.
|
109 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
110 |
+
#poetry.lock
|
111 |
+
|
112 |
+
# pdm
|
113 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
114 |
+
#pdm.lock
|
115 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
116 |
+
# in version control.
|
117 |
+
# https://pdm.fming.dev/#use-with-ide
|
118 |
+
.pdm.toml
|
119 |
+
|
120 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
121 |
+
__pypackages__/
|
122 |
+
|
123 |
+
# Celery stuff
|
124 |
+
celerybeat-schedule
|
125 |
+
celerybeat.pid
|
126 |
+
|
127 |
+
# SageMath parsed files
|
128 |
+
*.sage.py
|
129 |
+
|
130 |
+
# Environments
|
131 |
+
.env
|
132 |
+
.venv
|
133 |
+
env/
|
134 |
+
venv/
|
135 |
+
ENV/
|
136 |
+
env.bak/
|
137 |
+
venv.bak/
|
138 |
+
|
139 |
+
|
140 |
+
!configs/env
|
141 |
+
|
142 |
+
# Spyder project settings
|
143 |
+
.spyderproject
|
144 |
+
.spyproject
|
145 |
+
|
146 |
+
# Rope project settings
|
147 |
+
.ropeproject
|
148 |
+
|
149 |
+
# mkdocs documentation
|
150 |
+
/site
|
151 |
+
|
152 |
+
# mypy
|
153 |
+
.mypy_cache/
|
154 |
+
.dmypy.json
|
155 |
+
dmypy.json
|
156 |
+
|
157 |
+
# Pyre type checker
|
158 |
+
.pyre/
|
159 |
+
|
160 |
+
# pytype static type analyzer
|
161 |
+
.pytype/
|
162 |
+
|
163 |
+
# Cython debug symbols
|
164 |
+
cython_debug/
|
165 |
+
|
166 |
+
# PyCharm
|
167 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
168 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
169 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
170 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
171 |
+
.idea/
|
172 |
+
texture_cache.pbz2
|
173 |
+
texture_cache*.pbz2
|
174 |
+
profile*
|
175 |
+
wandb_key
|
176 |
+
test.py
|
177 |
+
outputs
|
178 |
+
lol*
|
179 |
+
.cache-location
|
180 |
+
experiments/ppo_old.py
|
181 |
+
.bash_history
|
182 |
+
logs/
|
183 |
+
.vscode
|
184 |
+
|
185 |
+
kinetix/util/old_learning_with_mask.py
|
186 |
+
offline/datasets/*.pkl
|
187 |
+
all_sweeps
|
188 |
+
worlds/games
|
189 |
+
|
190 |
+
artifacts
|
191 |
+
log*_*
|
192 |
+
kinetix/analysis/test*.py
|
193 |
+
slurm-*.out
|
194 |
+
results/
|
Kinetix/.pre-commit-config.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
repos:
|
2 |
+
- repo: https://github.com/psf/black
|
3 |
+
rev: 22.3.0
|
4 |
+
hooks:
|
5 |
+
- id: black
|
6 |
+
language_version: python3
|
7 |
+
args: [--line-length=120]
|
Kinetix/LICENSE
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Copyright (c) 2024 Michael Matthews
|
2 |
+
|
3 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
4 |
+
of this software and associated documentation files (the "Software"), to deal
|
5 |
+
in the Software without restriction, including without limitation the rights
|
6 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
7 |
+
copies of the Software, and to permit persons to whom the Software is
|
8 |
+
furnished to do so, subject to the following conditions:
|
9 |
+
|
10 |
+
The above copyright notice and this permission notice shall be included in all
|
11 |
+
copies or substantial portions of the Software.
|
12 |
+
|
13 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
15 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
16 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
17 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
18 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
19 |
+
SOFTWARE.
|
Kinetix/README.md
ADDED
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<p align="middle">
|
2 |
+
<img src="images/kinetix_logo.gif" width="500" />
|
3 |
+
</p>
|
4 |
+
|
5 |
+
<p align="center">
|
6 |
+
<a href= "https://pypi.org/project/jax2d/">
|
7 |
+
<img src="https://img.shields.io/badge/python-3.10%20%7C%203.11%20%7C%203.12-blue" /></a>
|
8 |
+
<a href= "https://github.com/FLAIROx/Kinetix/blob/main/LICENSE">
|
9 |
+
<img src="https://img.shields.io/badge/License-MIT-yellow" /></a>
|
10 |
+
<a href= "https://github.com/psf/black">
|
11 |
+
<img src="https://img.shields.io/badge/code%20style-black-000000.svg" /></a>
|
12 |
+
<a href= "https://kinetix-env.github.io/">
|
13 |
+
<img src="https://img.shields.io/badge/online-editor-purple" /></a>
|
14 |
+
<a href= "https://arxiv.org/abs/2410.23208">
|
15 |
+
<img src="https://img.shields.io/badge/arxiv-2410.23208-b31b1b" /></a>
|
16 |
+
<a href= "./docs/README.md">
|
17 |
+
<img src="https://img.shields.io/badge/docs-green" /></a>
|
18 |
+
</p>
|
19 |
+
|
20 |
+
# Kinetix
|
21 |
+
|
22 |
+
Kinetix is a framework for reinforcement learning in a 2D rigid-body physics world, written entirely in [JAX](https://github.com/jax-ml/jax).
|
23 |
+
Kinetix can represent a huge array of physics-based tasks within a unified framework.
|
24 |
+
We use Kinetix to investigate the training of large, general reinforcement learning agents by procedurally generating millions of tasks for training.
|
25 |
+
You can play with Kinetix in our [online editor](https://kinetix-env.github.io/), or have a look at the JAX [physics engine](https://github.com/MichaelTMatthews/Jax2D) and [graphics library](https://github.com/FLAIROx/JaxGL) we made for Kinetix. Finally, see our [docs](./docs/README.md) for more information and more in-depth examples.
|
26 |
+
|
27 |
+
<p align="middle">
|
28 |
+
<img src="images/bb.gif" width="200" />
|
29 |
+
<img src="images/cartpole.gif" width="200" />
|
30 |
+
<img src="images/grasper.gif" width="200" />
|
31 |
+
</p>
|
32 |
+
<p align="middle">
|
33 |
+
<img src="images/hc.gif" width="200" />
|
34 |
+
<img src="images/hopper.gif" width="200" />
|
35 |
+
<img src="images/ll.gif" width="200" />
|
36 |
+
</p>
|
37 |
+
|
38 |
+
<p align="middle">
|
39 |
+
<b>The above shows specialist agents trained on their respective levels.</b>
|
40 |
+
</p>
|
41 |
+
|
42 |
+
# 📊 Paper TL; DR
|
43 |
+
|
44 |
+
|
45 |
+
|
46 |
+
We train a general agent on millions of procedurally generated physics tasks.
|
47 |
+
Every task has the same goal: make the <span style="color:green">green</span> and <span style="color:blue">blue</span> touch, without <span style="color:green">green</span> touching <span style="color:red">red</span>.
|
48 |
+
The agent can act through applying torque via motors and force via thrusters.
|
49 |
+
|
50 |
+
<p align="middle">
|
51 |
+
<img src="images/random_1.gif" width="200" />
|
52 |
+
<img src="images/random_5.gif" width="200" />
|
53 |
+
<img src="images/random_3.gif" width="200" />
|
54 |
+
</p>
|
55 |
+
<p align="middle">
|
56 |
+
<img src="images/random_4.gif" width="200" />
|
57 |
+
<img src="images/random_6.gif" width="200" />
|
58 |
+
<img src="images/random_7.gif" width="200" />
|
59 |
+
</p>
|
60 |
+
|
61 |
+
<p align="middle">
|
62 |
+
<b>The above shows a general agent zero-shotting unseen randomly generated levels.</b>
|
63 |
+
</p>
|
64 |
+
|
65 |
+
We then investigate the transfer capabilities of this agent to unseen handmade levels.
|
66 |
+
We find that the agent can zero-shot simple physics problems, but still struggles with harder tasks.
|
67 |
+
|
68 |
+
<p align="middle">
|
69 |
+
<img src="images/general_1.gif" width="200" />
|
70 |
+
<img src="images/general_2.gif" width="200" />
|
71 |
+
<img src="images/general_3.gif" width="200" />
|
72 |
+
</p>
|
73 |
+
<p align="middle">
|
74 |
+
<img src="images/general_4.gif" width="200" />
|
75 |
+
<img src="images/general_5.gif" width="200" />
|
76 |
+
<img src="images/general_6.gif" width="200" />
|
77 |
+
</p>
|
78 |
+
|
79 |
+
<p align="middle">
|
80 |
+
<b>The above shows a general agent zero-shotting unseen handmade levels.</b>
|
81 |
+
</p>
|
82 |
+
|
83 |
+
|
84 |
+
# 📜 Basic Usage
|
85 |
+
|
86 |
+
Kinetix follows the interfaces established in [gymnax](https://github.com/RobertTLange/gymnax) and [jaxued](https://github.com/DramaCow/jaxued):
|
87 |
+
|
88 |
+
```python
|
89 |
+
# Use default parameters
|
90 |
+
env_params = EnvParams()
|
91 |
+
static_env_params = StaticEnvParams()
|
92 |
+
ued_params = UEDParams()
|
93 |
+
|
94 |
+
# Create the environment
|
95 |
+
env = make_kinetix_env_from_args(
|
96 |
+
obs_type="pixels",
|
97 |
+
action_type="multidiscrete",
|
98 |
+
reset_type="replay",
|
99 |
+
static_env_params=static_env_params,
|
100 |
+
)
|
101 |
+
|
102 |
+
# Sample a random level
|
103 |
+
rng = jax.random.PRNGKey(0)
|
104 |
+
rng, _rng = jax.random.split(rng)
|
105 |
+
level = sample_kinetix_level(_rng, env.physics_engine, env_params, static_env_params, ued_params)
|
106 |
+
|
107 |
+
# Reset the environment state to this level
|
108 |
+
rng, _rng = jax.random.split(rng)
|
109 |
+
obs, env_state = env.reset_to_level(_rng, level, env_params)
|
110 |
+
|
111 |
+
# Take a step in the environment
|
112 |
+
rng, _rng = jax.random.split(rng)
|
113 |
+
action = env.action_space(env_params).sample(_rng)
|
114 |
+
rng, _rng = jax.random.split(rng)
|
115 |
+
obs, env_state, reward, done, info = env.step(_rng, env_state, action, env_params)
|
116 |
+
```
|
117 |
+
|
118 |
+
|
119 |
+
# ⬇️ Installation
|
120 |
+
To install Kinetix with a CUDA-enabled JAX backend (tested with python3.10):
|
121 |
+
```commandline
|
122 |
+
git clone https://github.com/FlairOx/Kinetix.git
|
123 |
+
cd Kinetix
|
124 |
+
pip install -e .
|
125 |
+
pre-commit install
|
126 |
+
```
|
127 |
+
|
128 |
+
# 🎯 Editor
|
129 |
+
We recommend using the [KinetixJS editor](https://kinetix-env.github.io/gallery.html?editor=true), but also provide a native (less polished) Kinetix editor.
|
130 |
+
|
131 |
+
To open this editor run the following command.
|
132 |
+
```commandline
|
133 |
+
python3 kinetix/editor.py
|
134 |
+
```
|
135 |
+
|
136 |
+
The controls in the editor are:
|
137 |
+
- Move between `edit` and `play` modes using `spacebar`
|
138 |
+
- In `edit` mode, the type of edit is shown by the icon at the top and is changed by scrolling the mouse wheel. For instance, by navigating to the rectangle editing function you can click to place a rectangle.
|
139 |
+
- You can also press the number keys to cycle between modes.
|
140 |
+
- To open handmade levels press ctrl-O and navigate to the ones in the L folder.
|
141 |
+
- **When playing a level use the arrow keys to control motors and the numeric keys (1, 2) to control thrusters.**
|
142 |
+
|
143 |
+
# 📈 Experiments
|
144 |
+
|
145 |
+
We have three primary experiment files,
|
146 |
+
1. [**SFL**](https://github.com/amacrutherford/sampling-for-learnability?tab=readme-ov-file): Training on levels with high learnability, this is how we trained our best general agents.
|
147 |
+
2. **PLR** PLR/DR/ACCEL in the [JAXUED](https://github.com/DramaCow/jaxued) style.
|
148 |
+
3. **PPO** Normal PPO in the [PureJaxRL](https://github.com/luchris429/purejaxrl/) style.
|
149 |
+
|
150 |
+
To run experiments with default parameters run any of the following:
|
151 |
+
```commandline
|
152 |
+
python3 experiments/sfl.py
|
153 |
+
python3 experiments/plr.py
|
154 |
+
python3 experiments/ppo.py
|
155 |
+
```
|
156 |
+
|
157 |
+
We use [hydra](https://hydra.cc/) for managing our configs. See the `configs/` folder for all the hydra configs that will be used by default.
|
158 |
+
If you want to run experiments with different configurations, you can either edit these configs or pass command line arguments as so:
|
159 |
+
|
160 |
+
```commandline
|
161 |
+
python3 experiments/sfl.py model.transformer_depth=8
|
162 |
+
```
|
163 |
+
|
164 |
+
These experiments use [wandb](https://wandb.ai/home) for logging by default.
|
165 |
+
|
166 |
+
## 🏋️ Training RL Agents
|
167 |
+
We provide several different ways to train RL agents, with the three most common options being, (a) [Training an agent on random levels](#training-on-random-levels), (b) [Training an agent on a single, hand-designed level](#training-on-a-single-hand-designed-level) or (c) [Training an agent on a set of hand-designed levels](#training-on-a-set-of-hand-designed-levels).
|
168 |
+
|
169 |
+
> [!WARNING]
|
170 |
+
> Kinetix has three different environment sizes, `s`, `m` and `l`. When running any of the scripts, you have to set the `env_size` option accordingly, for instance, `python3 experiments/ppo.py train_levels=random env_size=m` would train on random `m` levels.
|
171 |
+
> It will give an error if you try and load large levels into a small env size, for instance `python3 experiments/ppo.py train_levels=m env_size=s` would error.
|
172 |
+
|
173 |
+
### Training on random levels
|
174 |
+
This is the default option, but we give the explicit command for completeness
|
175 |
+
```commandline
|
176 |
+
python3 experiments/ppo.py train_levels=random
|
177 |
+
```
|
178 |
+
### Training on a single hand-designed level
|
179 |
+
|
180 |
+
> [!NOTE]
|
181 |
+
> Check the `worlds/` folder for handmade levels for each size category. By default, the loading functions require a relative path to the `worlds/` directory
|
182 |
+
|
183 |
+
```commandline
|
184 |
+
python3 experiments/ppo.py train_levels=s train_levels.train_levels_list='["s/h4_thrust_aim.json"]'
|
185 |
+
```
|
186 |
+
### Training on a set of hand-designed levels
|
187 |
+
```commandline
|
188 |
+
python3 experiments/ppo.py train_levels=s env_size=s eval_env_size=s
|
189 |
+
# python3 experiments/ppo.py train_levels=m env_size=m eval_env_size=m
|
190 |
+
# python3 experiments/ppo.py train_levels=l env_size=l eval_env_size=l
|
191 |
+
```
|
192 |
+
|
193 |
+
Or, on a custom set:
|
194 |
+
```commandline
|
195 |
+
python3 experiments/ppo.py train_levels=l eval_env_size=l env_size=l train_levels.train_levels_list='["s/h2_one_wheel_car","l/h11_obstacle_avoidance"]'
|
196 |
+
```
|
197 |
+
|
198 |
+
|
199 |
+
# 🔎 See Also
|
200 |
+
- 🌐 [Kinetix.js](https://github.com/Michael-Beukman/Kinetix.js) Kinetix reimplemented in Javascript, with a live demo [here](https://kinetix-env.github.io/gallery.html?editor=true).
|
201 |
+
- 🍎 [Jax2D](https://github.com/MichaelTMatthews/Jax2D) The physics engine we made for Kinetix.
|
202 |
+
- 👨💻 [JaxGL](https://github.com/FLAIROx/JaxGL) The graphics library we made for Kinetix.
|
203 |
+
- 📋 [Our Paper](https://arxiv.org/abs/2410.23208) for more details and empirical results.
|
204 |
+
|
205 |
+
# 📚 Citation
|
206 |
+
Please cite Kinetix it as follows:
|
207 |
+
```
|
208 |
+
@article{matthews2024kinetix,
|
209 |
+
title={Kinetix: Investigating the Training of General Agents through Open-Ended Physics-Based Control Tasks},
|
210 |
+
author={Michael Matthews and Michael Beukman and Chris Lu and Jakob Foerster},
|
211 |
+
year={2024},
|
212 |
+
eprint={2410.23208},
|
213 |
+
archivePrefix={arXiv},
|
214 |
+
primaryClass={cs.LG},
|
215 |
+
url={https://arxiv.org/abs/2410.23208},
|
216 |
+
}
|
217 |
+
```
|
Kinetix/configs/editor.yaml
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
defaults:
|
2 |
+
- env: entity
|
3 |
+
- env_size: l
|
4 |
+
- learning:
|
5 |
+
- ppo-base
|
6 |
+
- ppo-rnn
|
7 |
+
- misc: misc
|
8 |
+
- model:
|
9 |
+
- model-base
|
10 |
+
- model-transformer
|
11 |
+
- _self_
|
12 |
+
|
13 |
+
seed: 0
|
14 |
+
upscale: 2
|
15 |
+
downscale: 1
|
16 |
+
fps: 60
|
17 |
+
debug: true
|
18 |
+
|
19 |
+
env:
|
20 |
+
frame_skip: 1
|
21 |
+
|
22 |
+
agent_taking_actions: false
|
Kinetix/configs/env/entity.yaml
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
env_name: "Kinetix-Entity-MultiDiscrete-v1"
|
2 |
+
dense_reward_scale: 2.0
|
3 |
+
frame_skip: 2
|
Kinetix/configs/env/symbolic.yaml
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
env_name: "Kinetix-Symbolic-MultiDiscrete-v1"
|
2 |
+
dense_reward_scale: 2.0
|
3 |
+
frame_skip: 2
|
Kinetix/configs/env_size/custom.yaml
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
custom_path: worlds/l/grasp_easy.json
|
2 |
+
env_size_type: custom
|
3 |
+
env_size_name: custom
|
Kinetix/configs/env_size/l.yaml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
num_polygons: 12
|
2 |
+
num_circles: 4
|
3 |
+
num_joints: 6
|
4 |
+
num_thrusters: 2
|
5 |
+
env_size_name: l
|
6 |
+
num_motor_bindings: 4
|
7 |
+
num_thruster_bindings: 2
|
8 |
+
env_size_type: predefined
|
Kinetix/configs/env_size/m.yaml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
num_polygons: 6
|
2 |
+
num_circles: 3
|
3 |
+
num_joints: 2
|
4 |
+
num_thrusters: 2
|
5 |
+
env_size_name: m
|
6 |
+
num_motor_bindings: 4
|
7 |
+
num_thruster_bindings: 2
|
8 |
+
env_size_type: predefined
|
Kinetix/configs/env_size/s.yaml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
num_polygons: 5
|
2 |
+
num_circles: 2
|
3 |
+
num_joints: 1
|
4 |
+
num_thrusters: 1
|
5 |
+
env_size_name: s
|
6 |
+
num_motor_bindings: 4
|
7 |
+
num_thruster_bindings: 2
|
8 |
+
env_size_type: predefined
|
Kinetix/configs/eval/eval_all.yaml
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
eval_levels:
|
2 |
+
[
|
3 |
+
"s/h0_weak_thrust",
|
4 |
+
"s/h7_unicycle_left",
|
5 |
+
"s/h3_point_the_thruster",
|
6 |
+
"s/h4_thrust_aim",
|
7 |
+
"s/h1_thrust_over_ball",
|
8 |
+
"s/h5_rotate_fall",
|
9 |
+
"s/h9_explode_then_thrust_over",
|
10 |
+
"s/h6_unicycle_right",
|
11 |
+
"s/h8_unicycle_balance",
|
12 |
+
"s/h2_one_wheel_car",
|
13 |
+
|
14 |
+
"m/h0_unicycle",
|
15 |
+
"m/h1_car_left",
|
16 |
+
"m/h2_car_right",
|
17 |
+
"m/h3_car_thrust",
|
18 |
+
"m/h4_thrust_the_needle",
|
19 |
+
"m/h5_angry_birds",
|
20 |
+
"m/h6_thrust_over",
|
21 |
+
"m/h7_car_flip",
|
22 |
+
"m/h8_weird_vehicle",
|
23 |
+
"m/h9_spin_the_right_way",
|
24 |
+
"m/h10_thrust_right_easy",
|
25 |
+
"m/h11_thrust_left_easy",
|
26 |
+
"m/h12_thrustfall_left",
|
27 |
+
"m/h13_thrustfall_right",
|
28 |
+
"m/h14_thrustblock",
|
29 |
+
"m/h15_thrustshoot",
|
30 |
+
"m/h16_thrustcontrol_right",
|
31 |
+
"m/h17_thrustcontrol_left",
|
32 |
+
"m/h18_thrust_right_very_easy",
|
33 |
+
"m/h19_thrust_left_very_easy",
|
34 |
+
"m/arm_left",
|
35 |
+
"m/arm_right",
|
36 |
+
"m/arm_up",
|
37 |
+
"m/arm_hard",
|
38 |
+
|
39 |
+
"l/h0_angrybirds",
|
40 |
+
"l/h1_car_left",
|
41 |
+
"l/h2_car_ramp",
|
42 |
+
"l/h3_car_right",
|
43 |
+
"l/h4_cartpole",
|
44 |
+
"l/h5_flappy_bird",
|
45 |
+
"l/h6_lorry",
|
46 |
+
"l/h7_maze_1",
|
47 |
+
"l/h8_maze_2",
|
48 |
+
"l/h9_morph_direction",
|
49 |
+
"l/h10_morph_direction_2",
|
50 |
+
"l/h11_obstacle_avoidance",
|
51 |
+
"l/h12_platformer_1",
|
52 |
+
"l/h13_platformer_2",
|
53 |
+
"l/h14_simple_thruster",
|
54 |
+
"l/h15_swing_up",
|
55 |
+
"l/h16_thruster_goal",
|
56 |
+
"l/h17_unicycle",
|
57 |
+
"l/hard_beam_balance",
|
58 |
+
"l/hard_cartpole_thrust",
|
59 |
+
"l/hard_cartpole_wheels",
|
60 |
+
"l/hard_lunar_lander",
|
61 |
+
"l/hard_pinball",
|
62 |
+
"l/grasp_hard",
|
63 |
+
"l/grasp_easy",
|
64 |
+
"l/mjc_half_cheetah",
|
65 |
+
"l/mjc_half_cheetah_easy",
|
66 |
+
"l/mjc_hopper",
|
67 |
+
"l/mjc_hopper_easy",
|
68 |
+
"l/mjc_swimmer",
|
69 |
+
"l/mjc_walker",
|
70 |
+
"l/mjc_walker_easy",
|
71 |
+
"l/car_launch",
|
72 |
+
"l/car_swing_around",
|
73 |
+
"l/chain_lander",
|
74 |
+
"l/chain_thrust",
|
75 |
+
"l/gears",
|
76 |
+
"l/lever_puzzle",
|
77 |
+
"l/pr",
|
78 |
+
"l/rail",
|
79 |
+
]
|
80 |
+
eval_num_attempts: 10
|
81 |
+
eval_freq: 10
|
82 |
+
EVAL_ON_SAMPLED: false
|
Kinetix/configs/eval/eval_auto.yaml
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
eval_levels: "auto"
|
2 |
+
eval_num_attempts: 10
|
3 |
+
eval_freq: 10
|
4 |
+
EVAL_ON_SAMPLED: false
|
Kinetix/configs/eval/eval_general.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
eval_levels:
|
2 |
+
[
|
3 |
+
"easy.simple_thruster",
|
4 |
+
]
|
5 |
+
eval_num_attempts: 10
|
6 |
+
eval_freq: 10
|
7 |
+
EVAL_ON_SAMPLED: false
|
Kinetix/configs/eval/l.yaml
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
eval_levels:
|
2 |
+
[
|
3 |
+
"l/h0_angrybirds",
|
4 |
+
"l/h1_car_left",
|
5 |
+
"l/h2_car_ramp",
|
6 |
+
"l/h3_car_right",
|
7 |
+
"l/h4_cartpole",
|
8 |
+
"l/h5_flappy_bird",
|
9 |
+
"l/h6_lorry",
|
10 |
+
"l/h7_maze_1",
|
11 |
+
"l/h8_maze_2",
|
12 |
+
"l/h9_morph_direction",
|
13 |
+
"l/h10_morph_direction_2",
|
14 |
+
"l/h11_obstacle_avoidance",
|
15 |
+
"l/h12_platformer_1",
|
16 |
+
"l/h13_platformer_2",
|
17 |
+
"l/h14_simple_thruster",
|
18 |
+
"l/h15_swing_up",
|
19 |
+
"l/h16_thruster_goal",
|
20 |
+
"l/h17_unicycle",
|
21 |
+
"l/hard_beam_balance",
|
22 |
+
"l/hard_cartpole_thrust",
|
23 |
+
"l/hard_cartpole_wheels",
|
24 |
+
"l/hard_lunar_lander",
|
25 |
+
"l/hard_pinball",
|
26 |
+
"l/grasp_hard",
|
27 |
+
"l/grasp_easy",
|
28 |
+
"l/mjc_half_cheetah",
|
29 |
+
"l/mjc_half_cheetah_easy",
|
30 |
+
"l/mjc_hopper",
|
31 |
+
"l/mjc_hopper_easy",
|
32 |
+
"l/mjc_swimmer",
|
33 |
+
"l/mjc_walker",
|
34 |
+
"l/mjc_walker_easy",
|
35 |
+
"l/car_launch",
|
36 |
+
"l/car_swing_around",
|
37 |
+
"l/chain_lander",
|
38 |
+
"l/chain_thrust",
|
39 |
+
"l/gears",
|
40 |
+
"l/lever_puzzle",
|
41 |
+
"l/pr",
|
42 |
+
"l/rail",
|
43 |
+
]
|
44 |
+
eval_num_attempts: 10
|
45 |
+
eval_freq: 50
|
46 |
+
EVAL_ON_SAMPLED: true
|
Kinetix/configs/eval/m.yaml
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
eval_levels:
|
2 |
+
[
|
3 |
+
"m/h0_unicycle",
|
4 |
+
"m/h1_car_left",
|
5 |
+
"m/h2_car_right",
|
6 |
+
"m/h3_car_thrust",
|
7 |
+
"m/h4_thrust_the_needle",
|
8 |
+
"m/h5_angry_birds",
|
9 |
+
"m/h6_thrust_over",
|
10 |
+
"m/h7_car_flip",
|
11 |
+
"m/h8_weird_vehicle",
|
12 |
+
"m/h9_spin_the_right_way",
|
13 |
+
"m/h10_thrust_right_easy",
|
14 |
+
"m/h11_thrust_left_easy",
|
15 |
+
"m/h12_thrustfall_left",
|
16 |
+
"m/h13_thrustfall_right",
|
17 |
+
"m/h14_thrustblock",
|
18 |
+
"m/h15_thrustshoot",
|
19 |
+
"m/h16_thrustcontrol_right",
|
20 |
+
"m/h17_thrustcontrol_left",
|
21 |
+
"m/h18_thrust_right_very_easy",
|
22 |
+
"m/h19_thrust_left_very_easy",
|
23 |
+
"m/arm_left",
|
24 |
+
"m/arm_right",
|
25 |
+
"m/arm_up",
|
26 |
+
"m/arm_hard",
|
27 |
+
]
|
28 |
+
eval_num_attempts: 10
|
29 |
+
eval_freq: 50
|
30 |
+
EVAL_ON_SAMPLED: true
|
Kinetix/configs/eval/mujoco.yaml
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
eval_levels:
|
2 |
+
[
|
3 |
+
"l/mjc_half_cheetah",
|
4 |
+
"l/mjc_half_cheetah_easy",
|
5 |
+
"l/mjc_hopper",
|
6 |
+
"l/mjc_hopper_easy",
|
7 |
+
"l/mjc_swimmer",
|
8 |
+
"l/mjc_walker",
|
9 |
+
"l/mjc_walker_easy",
|
10 |
+
]
|
11 |
+
eval_num_attempts: 10
|
12 |
+
eval_freq: 10
|
13 |
+
EVAL_ON_SAMPLED: false
|
Kinetix/configs/eval/s.yaml
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
eval_levels:
|
2 |
+
[
|
3 |
+
"s/h0_weak_thrust",
|
4 |
+
"s/h7_unicycle_left",
|
5 |
+
"s/h3_point_the_thruster",
|
6 |
+
"s/h4_thrust_aim",
|
7 |
+
"s/h1_thrust_over_ball",
|
8 |
+
"s/h5_rotate_fall",
|
9 |
+
"s/h9_explode_then_thrust_over",
|
10 |
+
"s/h6_unicycle_right",
|
11 |
+
"s/h8_unicycle_balance",
|
12 |
+
"s/h2_one_wheel_car",
|
13 |
+
]
|
14 |
+
eval_num_attempts: 10
|
15 |
+
eval_freq: 50
|
16 |
+
EVAL_ON_SAMPLED: true
|
Kinetix/configs/eval_env_size/l.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
num_polygons: 12
|
2 |
+
num_circles: 4
|
3 |
+
num_joints: 6
|
4 |
+
num_thrusters: 2
|
5 |
+
env_size_name: l
|
6 |
+
num_motor_bindings: 4
|
7 |
+
num_thruster_bindings: 2
|
Kinetix/configs/eval_env_size/m.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
num_polygons: 6
|
2 |
+
num_circles: 3
|
3 |
+
num_joints: 2
|
4 |
+
num_thrusters: 2
|
5 |
+
env_size_name: m
|
6 |
+
num_motor_bindings: 4
|
7 |
+
num_thruster_bindings: 2
|
Kinetix/configs/eval_env_size/s.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
num_polygons: 5
|
2 |
+
num_circles: 2
|
3 |
+
num_joints: 1
|
4 |
+
num_thrusters: 1
|
5 |
+
env_size_name: s
|
6 |
+
num_motor_bindings: 4
|
7 |
+
num_thruster_bindings: 2
|
Kinetix/configs/learning/ppo-base.yaml
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
lr: 5e-5
|
2 |
+
peak_lr: 3e-4
|
3 |
+
initial_lr: 1e-5
|
4 |
+
warmup_frac: 0.1
|
5 |
+
max_grad_norm: 1.0
|
6 |
+
total_timesteps: 1073741824
|
7 |
+
num_train_envs: 2048
|
8 |
+
num_minibatches: 32
|
9 |
+
gamma: 0.995
|
10 |
+
update_epochs: 8
|
11 |
+
clip_eps: 0.2
|
12 |
+
gae_lambda: 0.9
|
13 |
+
ent_coef: 0.01
|
14 |
+
anneal_lr: false
|
15 |
+
warmup_lr: false
|
16 |
+
vf_coef: 0.5
|
17 |
+
permute_state_during_training: false
|
18 |
+
filter_levels: true
|
19 |
+
level_filter_n_steps: 64
|
20 |
+
level_filter_sample_ratio: 2
|
Kinetix/configs/learning/ppo-rnn.yaml
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
num_steps: 64
|
2 |
+
num_repeats: 1
|
Kinetix/configs/learning/ppo-sfl.yaml
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
num_steps: 512
|
Kinetix/configs/learning/ppo-ued.yaml
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
num_steps: 64
|
2 |
+
outer_rollout_steps: 4
|
Kinetix/configs/misc/misc.yaml
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
group: "auto"
|
2 |
+
group_auto_prefix: ""
|
3 |
+
save_path: "checkpoints/kinetix"
|
4 |
+
use_wandb: true
|
5 |
+
save_policy: true
|
6 |
+
wandb_project: "kinetix-experiments"
|
7 |
+
wandb_entity: null
|
8 |
+
wandb_mode : online
|
9 |
+
video_frequency: 10
|
10 |
+
load_from_checkpoint: null
|
11 |
+
load_only_params: true
|
12 |
+
checkpoint_save_freq: 512
|
13 |
+
checkpoint_human_numbers: false
|
14 |
+
load_legacy_checkpoint: false
|
15 |
+
load_train_levels_legacy: false
|
16 |
+
economical_saving: false
|
Kinetix/configs/model/model-base.yaml
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
fc_layer_depth: 5
|
2 |
+
fc_layer_width: 128
|
3 |
+
activation: "tanh"
|
4 |
+
recurrent_model: False
|
Kinetix/configs/model/model-transformer.yaml
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
transformer_depth: 2
|
2 |
+
transformer_size: 16
|
3 |
+
transformer_encoder_size: 128
|
4 |
+
num_heads: 8
|
5 |
+
full_attention_mask: false
|
6 |
+
aggregate_mode: dummy_and_mean
|
Kinetix/configs/plr.yaml
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
defaults:
|
2 |
+
- env: entity
|
3 |
+
- learning:
|
4 |
+
- ppo-base
|
5 |
+
- ppo-ued
|
6 |
+
- misc: misc
|
7 |
+
- env_size: s
|
8 |
+
- eval: s
|
9 |
+
- eval_env_size: s
|
10 |
+
- ued: plr
|
11 |
+
- train_levels: random
|
12 |
+
- model:
|
13 |
+
- model-base
|
14 |
+
- model-transformer
|
15 |
+
- _self_
|
16 |
+
|
17 |
+
seed: 0
|
Kinetix/configs/ppo.yaml
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
defaults:
|
2 |
+
- env: entity
|
3 |
+
- env_size: s
|
4 |
+
- learning:
|
5 |
+
- ppo-base
|
6 |
+
- ppo-rnn
|
7 |
+
- misc: misc
|
8 |
+
- eval: s
|
9 |
+
- eval_env_size: s
|
10 |
+
- train_levels: random
|
11 |
+
- model:
|
12 |
+
- model-base
|
13 |
+
- model-transformer
|
14 |
+
- _self_
|
15 |
+
|
16 |
+
|
17 |
+
eval:
|
18 |
+
eval_freq: 40
|
19 |
+
|
20 |
+
seed: 0
|
Kinetix/configs/sfl.yaml
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
defaults:
|
2 |
+
- env: entity
|
3 |
+
- learning:
|
4 |
+
- ppo-base
|
5 |
+
- ppo-rnn
|
6 |
+
- misc: misc
|
7 |
+
- ued: sfl
|
8 |
+
- env_size: s
|
9 |
+
- eval: s
|
10 |
+
- eval_env_size: s
|
11 |
+
- train_levels: random
|
12 |
+
- model:
|
13 |
+
- model-base
|
14 |
+
- model-transformer
|
15 |
+
- _self_
|
16 |
+
|
17 |
+
eval:
|
18 |
+
eval_freq: 128
|
19 |
+
learning:
|
20 |
+
num_steps: 256
|
21 |
+
seed: 0
|
Kinetix/configs/train_levels/l.yaml
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
train_level_mode: list
|
2 |
+
train_levels_list:
|
3 |
+
[
|
4 |
+
"l/h0_angrybirds",
|
5 |
+
"l/h1_car_left",
|
6 |
+
"l/h2_car_ramp",
|
7 |
+
"l/h3_car_right",
|
8 |
+
"l/h4_cartpole",
|
9 |
+
"l/h5_flappy_bird",
|
10 |
+
"l/h6_lorry",
|
11 |
+
"l/h7_maze_1",
|
12 |
+
"l/h8_maze_2",
|
13 |
+
"l/h9_morph_direction",
|
14 |
+
"l/h10_morph_direction_2",
|
15 |
+
"l/h11_obstacle_avoidance",
|
16 |
+
"l/h12_platformer_1",
|
17 |
+
"l/h13_platformer_2",
|
18 |
+
"l/h14_simple_thruster",
|
19 |
+
"l/h15_swing_up",
|
20 |
+
"l/h16_thruster_goal",
|
21 |
+
"l/h17_unicycle",
|
22 |
+
"l/hard_beam_balance",
|
23 |
+
"l/hard_cartpole_thrust",
|
24 |
+
"l/hard_cartpole_wheels",
|
25 |
+
"l/hard_lunar_lander",
|
26 |
+
"l/hard_pinball",
|
27 |
+
"l/mjc_half_cheetah",
|
28 |
+
"l/mjc_half_cheetah_easy",
|
29 |
+
"l/mjc_hopper",
|
30 |
+
"l/mjc_hopper_easy",
|
31 |
+
"l/mjc_swimmer",
|
32 |
+
"l/mjc_walker",
|
33 |
+
"l/mjc_walker_easy",
|
34 |
+
"l/grasp_hard",
|
35 |
+
"l/grasp_easy",
|
36 |
+
"l/car_launch",
|
37 |
+
"l/car_swing_around",
|
38 |
+
"l/chain_lander",
|
39 |
+
"l/chain_thrust",
|
40 |
+
"l/gears",
|
41 |
+
"l/lever_puzzle",
|
42 |
+
"l/pr",
|
43 |
+
"l/rail",
|
44 |
+
]
|
Kinetix/configs/train_levels/m.yaml
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
train_level_mode: list
|
2 |
+
train_levels_list:
|
3 |
+
[
|
4 |
+
"m/h0_unicycle",
|
5 |
+
"m/h1_car_left",
|
6 |
+
"m/h2_car_right",
|
7 |
+
"m/h3_car_thrust",
|
8 |
+
"m/h4_thrust_the_needle",
|
9 |
+
"m/h5_angry_birds",
|
10 |
+
"m/h6_thrust_over",
|
11 |
+
"m/h7_car_flip",
|
12 |
+
"m/h8_weird_vehicle",
|
13 |
+
"m/h9_spin_the_right_way",
|
14 |
+
"m/h10_thrust_right_easy",
|
15 |
+
"m/h11_thrust_left_easy",
|
16 |
+
"m/h12_thrustfall_left",
|
17 |
+
"m/h13_thrustfall_right",
|
18 |
+
"m/h14_thrustblock",
|
19 |
+
"m/h15_thrustshoot",
|
20 |
+
"m/h16_thrustcontrol_right",
|
21 |
+
"m/h17_thrustcontrol_left",
|
22 |
+
"m/h18_thrust_right_very_easy",
|
23 |
+
"m/h19_thrust_left_very_easy",
|
24 |
+
"m/arm_left",
|
25 |
+
"m/arm_right",
|
26 |
+
"m/arm_up",
|
27 |
+
"m/arm_hard",
|
28 |
+
]
|
Kinetix/configs/train_levels/mujoco.yaml
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
train_level_mode: list
|
2 |
+
train_levels_list:
|
3 |
+
[
|
4 |
+
"l/mjc_half_cheetah",
|
5 |
+
"l/mjc_half_cheetah_easy",
|
6 |
+
"l/mjc_hopper",
|
7 |
+
"l/mjc_hopper_easy",
|
8 |
+
"l/mjc_swimmer",
|
9 |
+
"l/mjc_walker",
|
10 |
+
"l/mjc_walker_easy",
|
11 |
+
]
|
Kinetix/configs/train_levels/random.yaml
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
train_level_mode: random
|
2 |
+
train_level_distribution: distribution_v3
|
Kinetix/configs/train_levels/s.yaml
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
train_level_mode: list
|
2 |
+
train_levels_list:
|
3 |
+
[
|
4 |
+
"s/h0_weak_thrust",
|
5 |
+
"s/h7_unicycle_left",
|
6 |
+
"s/h3_point_the_thruster",
|
7 |
+
"s/h4_thrust_aim",
|
8 |
+
"s/h1_thrust_over_ball",
|
9 |
+
"s/h5_rotate_fall",
|
10 |
+
"s/h9_explode_then_thrust_over",
|
11 |
+
"s/h6_unicycle_right",
|
12 |
+
"s/h8_unicycle_balance",
|
13 |
+
"s/h2_one_wheel_car",
|
14 |
+
]
|
Kinetix/configs/train_levels/train_all.yaml
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
train_level_mode: list
|
2 |
+
train_levels_list:
|
3 |
+
[
|
4 |
+
"s/h0_weak_thrust",
|
5 |
+
"s/h7_unicycle_left",
|
6 |
+
"s/h3_point_the_thruster",
|
7 |
+
"s/h4_thrust_aim",
|
8 |
+
"s/h1_thrust_over_ball",
|
9 |
+
"s/h5_rotate_fall",
|
10 |
+
"s/h9_explode_then_thrust_over",
|
11 |
+
"s/h6_unicycle_right",
|
12 |
+
"s/h8_unicycle_balance",
|
13 |
+
"s/h2_one_wheel_car",
|
14 |
+
|
15 |
+
"m/h0_unicycle",
|
16 |
+
"m/h1_car_left",
|
17 |
+
"m/h2_car_right",
|
18 |
+
"m/h3_car_thrust",
|
19 |
+
"m/h4_thrust_the_needle",
|
20 |
+
"m/h5_angry_birds",
|
21 |
+
"m/h6_thrust_over",
|
22 |
+
"m/h7_car_flip",
|
23 |
+
"m/h8_weird_vehicle",
|
24 |
+
"m/h9_spin_the_right_way",
|
25 |
+
"m/h10_thrust_right_easy",
|
26 |
+
"m/h11_thrust_left_easy",
|
27 |
+
"m/h12_thrustfall_left",
|
28 |
+
"m/h13_thrustfall_right",
|
29 |
+
"m/h14_thrustblock",
|
30 |
+
"m/h15_thrustshoot",
|
31 |
+
"m/h16_thrustcontrol_right",
|
32 |
+
"m/h17_thrustcontrol_left",
|
33 |
+
"m/h18_thrust_right_very_easy",
|
34 |
+
"m/h19_thrust_left_very_easy",
|
35 |
+
"m/arm_left",
|
36 |
+
"m/arm_right",
|
37 |
+
"m/arm_up",
|
38 |
+
"m/arm_hard",
|
39 |
+
|
40 |
+
"l/h0_angrybirds",
|
41 |
+
"l/h1_car_left",
|
42 |
+
"l/h2_car_ramp",
|
43 |
+
"l/h3_car_right",
|
44 |
+
"l/h4_cartpole",
|
45 |
+
"l/h5_flappy_bird",
|
46 |
+
"l/h6_lorry",
|
47 |
+
"l/h7_maze_1",
|
48 |
+
"l/h8_maze_2",
|
49 |
+
"l/h9_morph_direction",
|
50 |
+
"l/h10_morph_direction_2",
|
51 |
+
"l/h11_obstacle_avoidance",
|
52 |
+
"l/h12_platformer_1",
|
53 |
+
"l/h13_platformer_2",
|
54 |
+
"l/h14_simple_thruster",
|
55 |
+
"l/h15_swing_up",
|
56 |
+
"l/h16_thruster_goal",
|
57 |
+
"l/h17_unicycle",
|
58 |
+
"l/hard_beam_balance",
|
59 |
+
"l/hard_cartpole_thrust",
|
60 |
+
"l/hard_cartpole_wheels",
|
61 |
+
"l/hard_lunar_lander",
|
62 |
+
"l/hard_pinball",
|
63 |
+
"l/grasp_hard",
|
64 |
+
"l/grasp_easy",
|
65 |
+
"l/mjc_half_cheetah",
|
66 |
+
"l/mjc_half_cheetah_easy",
|
67 |
+
"l/mjc_hopper",
|
68 |
+
"l/mjc_hopper_easy",
|
69 |
+
"l/mjc_swimmer",
|
70 |
+
"l/mjc_walker",
|
71 |
+
"l/mjc_walker_easy",
|
72 |
+
"l/car_launch",
|
73 |
+
"l/car_swing_around",
|
74 |
+
"l/chain_lander",
|
75 |
+
"l/chain_thrust",
|
76 |
+
"l/gears",
|
77 |
+
"l/lever_puzzle",
|
78 |
+
"l/pr",
|
79 |
+
"l/rail",
|
80 |
+
]
|
Kinetix/configs/ued/accel.yaml
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use_accel: true
|
2 |
+
exploratory_grad_updates: true
|
3 |
+
num_edits: 5
|
4 |
+
score_function: MaxMC
|
5 |
+
level_buffer_capacity: 4000
|
6 |
+
replay_prob: 0.5
|
7 |
+
staleness_coeff: 0.3
|
8 |
+
temperature: 1.0
|
9 |
+
topk_k: 8
|
10 |
+
minimum_fill_ratio: 0.5
|
11 |
+
prioritization: rank
|
12 |
+
buffer_duplicate_check: false
|
13 |
+
buffer_train: false
|
14 |
+
mode: train
|
15 |
+
checkpoint_directory: checkpoints/physicsenv/ued
|
16 |
+
max_number_of_checkpoints: 5
|
Kinetix/configs/ued/plr.yaml
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use_accel: false
|
2 |
+
exploratory_grad_updates: true
|
3 |
+
num_edits: 2
|
4 |
+
score_function: MaxMC
|
5 |
+
level_buffer_capacity: 4000
|
6 |
+
replay_prob: 0.5
|
7 |
+
staleness_coeff: 0.3
|
8 |
+
temperature: 1.0
|
9 |
+
topk_k: 8
|
10 |
+
minimum_fill_ratio: 0.5
|
11 |
+
prioritization: rank
|
12 |
+
buffer_duplicate_check: false
|
13 |
+
buffer_train: false
|
14 |
+
mode: train
|
15 |
+
checkpoint_directory: checkpoints/physicsenv/ued
|
16 |
+
max_number_of_checkpoints: 5
|
17 |
+
accel_start_from_empty: True
|
Kinetix/configs/ued/sfl.yaml
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"sampled_envs_ratio": 0.5
|
2 |
+
"batch_size": 4096
|
3 |
+
"num_batches": 3
|
4 |
+
"rollout_steps": 512
|
5 |
+
"num_to_save": 1024
|
6 |
+
|
7 |
+
log_learnability_before_after: false
|
8 |
+
put_eval_levels_in_buffer: false
|
9 |
+
save_learnability_buffer_pickle: false
|
Kinetix/docs/README.md
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Documentation
|
2 |
+
This is intended to provide some more details about how Kinetix works, including more in-depth examples. If you are interested in the configuration options, see [here](./configs.md).
|
3 |
+
|
4 |
+
- [Documentation](#documentation)
|
5 |
+
- [Different Versions of Kinetix Environments](#different-versions-of-kinetix-environments)
|
6 |
+
- [Action Spaces](#action-spaces)
|
7 |
+
- [Observation Spaces](#observation-spaces)
|
8 |
+
- [Resetting Functionality](#resetting-functionality)
|
9 |
+
- [Using Kinetix to easily design your own JAX Environments](#using-kinetix-to-easily-design-your-own-jax-environments)
|
10 |
+
- [Step 1 - Design an Environment](#step-1---design-an-environment)
|
11 |
+
- [Step 2 - Export It](#step-2---export-it)
|
12 |
+
- [Step 3 - Import It](#step-3---import-it)
|
13 |
+
- [Step 4 - Train](#step-4---train)
|
14 |
+
|
15 |
+
|
16 |
+
## Different Versions of Kinetix Environments
|
17 |
+
We provide several different variations on the standard Kinetix environment, where the primary difference is the action and observation spaces.
|
18 |
+
|
19 |
+
Each of the environments has a different name, of the following form: `Kinetix-<OBS>-<ACTION>-v1`, and can be made using the `make_kinetix_env_from_name` helper function.
|
20 |
+
### Action Spaces
|
21 |
+
For all action spaces, the agent can control joints and thrusters. Joints have a property `motor_binding`, which is a way to tie different joints to the same action. Two joints that have the same binding will always perform the same action, likewise for thrusters.
|
22 |
+
|
23 |
+
We have three observation spaces, discrete, continuous and multi-discrete (which is the default).
|
24 |
+
- **Discrete** has `2 * num_motor_bindings + num_thruster_bindings + 1` options, one of which can be active at any time. There are two options for every joint, i.e., backward and forward at full power. There is one option for each thruster, to activate it at full power. The final option is a no-op, meaning that no torque or force is applied to joints/thrusters.
|
25 |
+
- **Continuous** has shape `num_motor_bindings + num_thruster_bindings`, where each motor element can take a value between -1 and 1, and thruster elements can take values between 0 and 1.
|
26 |
+
- **Multi-Discrete**: This is a discrete action space, but allows multiple joints and thrusters to be active at any one time. The agent must output a flat vector of size `3 * num_motor_bindings + 2 * num_thruster_bindings`. For joints, each group of three represents a categorical distribution of `[0, -1, +1]` and for thrusters it represents `[0, +1]`.
|
27 |
+
|
28 |
+
### Observation Spaces
|
29 |
+
We provide three primary observation spaces, Symbolic-Flat (called just symbolic), Symbolic-Entity (called entity, which is also the default) and Pixels.
|
30 |
+
- **Symbolic-Flat** returns a large vector, which is the flattened representation of all shapes and their properties.
|
31 |
+
- **Symbolic-Entity** also returns a vector representation of all entities, but does not flatten it, instead returning it in a form that can be used with permutation-invariant network architectures, such as transformers.
|
32 |
+
- **Pixels** returns an image representation of the scene. This is partially observable, as features such as the restitution and density of shapes is not shown.
|
33 |
+
|
34 |
+
|
35 |
+
Each observation space has its own pros and cons. **Symbolic-Flat** is the fastest by far, but has two clear downsides. First, it is restricted to a single environment size, e.g. a model trained on `small` cannot be run on `medium` levels. Second, due to the large number of symmetries (e.g. any permutation of the same shapes would represent the same scene but would look very different in this observation space), this generalises worse than *entity*.
|
36 |
+
|
37 |
+
**Symbolic-Entity** is faster than pixels, but slower than Symbolic-Flat. However, it can be applied to any number of shapes, and is natively permutation invariant. For these reasons we chose it as the default option.
|
38 |
+
|
39 |
+
Finally, **Pixels** runs the slowest, and also requires more memory, which means that we cannot run as many parallel environments. However, pixels is potentially the most general format, and could theoretically allow transfer to other domains and simulators.
|
40 |
+
|
41 |
+
|
42 |
+
## Resetting Functionality
|
43 |
+
We have two primary resetting functions that control the environment's behaviour when an episode ends. The first of these is to train on a known, predefined set of levels, and resetting samples a new level from this set. In the extreme case, this also allows training only on a single level in the standard RL manner. The other main way of resetting is to sample a *random* level from some distribution, meaning that it is exceedingly unlikely to sample the same level twice.
|
44 |
+
|
45 |
+
## Using Kinetix to easily design your own JAX Environments
|
46 |
+
Since Kinetix has a general physics engine, you can design your own environments and train RL agents on them very fast! This section in the docs describes this pipeline.
|
47 |
+
### Step 1 - Design an Environment
|
48 |
+
You can go to our [online editor](https://kinetix-env.github.io/gallery.html?editor=true). You can also have a look at the [gallery](https://kinetix-env.github.io/gallery.html) if you need some inspiration.
|
49 |
+
|
50 |
+
The following two images show the main editor page, and then the level I designed, where you have to spin the ball the right way. While designing the level, you can play it to test it out, seeing if it is possible and of the appropriate difficulty.
|
51 |
+
|
52 |
+
|
53 |
+
<p align="middle">
|
54 |
+
<img src="../images/docs/edit-1.png" width="49%" />
|
55 |
+
<img src="../images/docs/edit-2.png" width="49%" />
|
56 |
+
</p>
|
57 |
+
|
58 |
+
### Step 2 - Export It
|
59 |
+
Once you are satisfied with your level, you can download it as a json file by using the button on the bottom left. Once this is downloaded, move it to `$KINETIX_ROOT/worlds/custom/my_custom_level.json`, where `$KINETIX_ROOT` is the root of the Kinetix repo.
|
60 |
+
|
61 |
+
|
62 |
+
### Step 3 - Import It
|
63 |
+
In python, you can import the level as follows, see `examples/example_premade_level_replay.py` for an example.
|
64 |
+
```python
|
65 |
+
from kinetix.util.saving import load_from_json_file
|
66 |
+
level, static_env_params, env_params = load_from_json_file("worlds/custom/my_custom_level.json")
|
67 |
+
```
|
68 |
+
|
69 |
+
### Step 4 - Train
|
70 |
+
You can use the above if you want to import the level and play around with it. If you want to train an RL agent on this level, you can do the following (see [here](https://github.com/FLAIROx/Kinetix?tab=readme-ov-file#training-on-a-single-hand-designed-level) from the main README).
|
71 |
+
|
72 |
+
```commandline
|
73 |
+
python3 experiments/ppo.py env_size=custom \
|
74 |
+
env_size.custom_path=custom/my_custom_level.json \
|
75 |
+
train_levels=s \
|
76 |
+
train_levels.train_levels_list='["custom/my_custom_level.json"]' \
|
77 |
+
eval=eval_auto
|
78 |
+
```
|
79 |
+
|
80 |
+
And the agent will start training, with videos on this on [wandb](https://wandb.ai).
|
81 |
+
<p align="middle">
|
82 |
+
<img src="../images/docs/wandb.gif" width="49%" />
|
83 |
+
</p>
|
Kinetix/docs/configs.md
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Configuration
|
2 |
+
|
3 |
+
- [Configuration](#configuration)
|
4 |
+
- [Configuration Headings](#configuration-headings)
|
5 |
+
- [Env](#env)
|
6 |
+
- [Env Size](#env-size)
|
7 |
+
- [Learning](#learning)
|
8 |
+
- [Misc](#misc)
|
9 |
+
- [Eval](#eval)
|
10 |
+
- [Eval Env Size](#eval-env-size)
|
11 |
+
- [Train Levels](#train-levels)
|
12 |
+
- [Model](#model)
|
13 |
+
- [UED](#ued)
|
14 |
+
|
15 |
+
|
16 |
+
We use [hydra](hydra.cc) for all of our configurations, and we use [hierarchical configuration](https://hydra.cc/docs/tutorials/structured_config/schema/) to organise everything better.
|
17 |
+
|
18 |
+
In particular, we have the following configuration headings, with the base `ppo` config looking like:
|
19 |
+
```yaml
|
20 |
+
defaults:
|
21 |
+
- env: entity
|
22 |
+
- env_size: s
|
23 |
+
- learning:
|
24 |
+
- ppo-base
|
25 |
+
- ppo-rnn
|
26 |
+
- misc: misc
|
27 |
+
- eval: s
|
28 |
+
- eval_env_size: s
|
29 |
+
- train_levels: random
|
30 |
+
- model:
|
31 |
+
- model-base
|
32 |
+
- model-transformer
|
33 |
+
- _self_
|
34 |
+
seed: 0
|
35 |
+
```
|
36 |
+
|
37 |
+
## Configuration Headings
|
38 |
+
### Env
|
39 |
+
This controls the environment to be used.
|
40 |
+
#### Preset Options
|
41 |
+
We provide two options in `configs/env`, namely `entity` and `symbolic`; each of these can be used by running `python3 experiments/ppo.py env=symbolic` or `python3 experiments/ppo.py env=entity`. If you wish to customise the options further, you can add any of the following subkeys (e.g. by running `python3 experiments/ppo.py env=symbolic env.dense_reward_scale=0.0`):
|
42 |
+
#### Individual Subkeys
|
43 |
+
- `env.env_name`: The name of the environment, with controls the observation and action space.
|
44 |
+
- `env.dense_reward_scale`: How large the dense reward scale is, set this to zero to disable dense rewards.
|
45 |
+
- `env.frame_skip`: The number of frames to skip, setting this to 2 (the default) seems to perform better.
|
46 |
+
### Env Size
|
47 |
+
This controls the maximum number of shapes present in the simulation. This has two important tradeoffs, namely speed and representational power: Small environments run much faster but some complex environments require a large number of shapes. See `configs/env_size`
|
48 |
+
#### Preset Options
|
49 |
+
- `s`: The `small` preset
|
50 |
+
- `m`: `Medium` preset
|
51 |
+
- `l`: `Large` preset
|
52 |
+
- `custom`: Allows the use of a custom environment size loaded from a json file (see [here](#train-levels) for more).
|
53 |
+
#### Individual Subkeys
|
54 |
+
- `num_polygons`: How many polygons
|
55 |
+
- `num_circles`: How many circles
|
56 |
+
- `num_joints`: How many joints
|
57 |
+
- `num_thrusters`: How many thrusters
|
58 |
+
- `env_size_name`: "s", "m" or "l"
|
59 |
+
- `num_motor_bindings`: How many different joint bindings are there, meaning how many different actions are there associated with joints. All joints with the same binding will have the same action applied to them.
|
60 |
+
- `num_thruster_bindings`: How many different thruster bindings are there
|
61 |
+
- `env_size_type`: "predefined" or "custom"
|
62 |
+
- `custom_path`: **Only for env_size_type=custom**, controls the json file to load the custom environment size from.
|
63 |
+
### Learning
|
64 |
+
This controls the agent's learning, see `configs/learning`
|
65 |
+
#### Preset Options
|
66 |
+
- `ppo-base`: This has all of the base PPO parameters, and is used by all methods
|
67 |
+
- `ppo-rnn`: This has the PureJaxRL settings for some of PPO's hyperparameters (mainly `num_steps` is different)
|
68 |
+
- `ppo-sfl`: This has the SFL-specific value of `num_steps`
|
69 |
+
- `ppo-ued`: This has the JAXUED-specific `num_steps` and `outer_rollout_steps`
|
70 |
+
#### Individual Subkeys
|
71 |
+
- `lr`: Learning Rate
|
72 |
+
- `anneal_lr`: Whether to anneal LR
|
73 |
+
- `warmup_lr`: Whether to warmup LR
|
74 |
+
- `peak_lr`: If warming up, the peak
|
75 |
+
- `initial_lr`: If warming up, the initial LR
|
76 |
+
- `warmup_frac`: If warming up, the warmup fraction of training time
|
77 |
+
- `max_grad_norm`: Maximum grad norm
|
78 |
+
- `total_timesteps`: How many total environment interactions must be run
|
79 |
+
- `num_train_envs`: Number of parallel environments to run simultaneously
|
80 |
+
- `num_minibatches`: Minibatches for PPO learning
|
81 |
+
- `gamma`: Discount factor
|
82 |
+
- `update_epochs`: PPO update epochs
|
83 |
+
- `clip_eps`: PPO clipping epsilon
|
84 |
+
- `gae_lambda`: PPO Lambda for GAE
|
85 |
+
- `ent_coef`: Entropy loss coefficient
|
86 |
+
- `vf_coef`: Value function loss coefficient
|
87 |
+
- `permute_state_during_training`: If true, the state is permuted on every reset.
|
88 |
+
- `filter_levels`: If true, and we are training on random levels, this filters out levels that can be solved by a no-op
|
89 |
+
- `level_filter_n_steps`: How many steps to allocate to the no-op policy for filtering
|
90 |
+
- `level_filter_sample_ratio`: How many more levels to sample than required (ideally `level_filter_sample_ratio` is more than the fraction that will be filtered out).
|
91 |
+
- `num_steps`: PPO rollout length
|
92 |
+
- `outer_rollout_steps`: How many learning steps to do for e.g. PLR for each rollout (see the [Craftax paper](https://arxiv.org/abs/2402.16801) for a more in-depth explanation).
|
93 |
+
### Misc
|
94 |
+
There are a plethora of miscellaneous options that are grouped under the `misc` category. There is only one preset option, `configs/misc/misc.yaml`.
|
95 |
+
#### Individual Subkeys
|
96 |
+
- `group`: Wandb group ("auto" usually works well)
|
97 |
+
- `group_auto_prefix`: If using group=auto, this is a user-defined prefix
|
98 |
+
- `save_path`: Where to save checkpoints to
|
99 |
+
- `use_wandb`: Should wandb be logged to
|
100 |
+
- `save_policy`: Should we save the policy
|
101 |
+
- `wandb_project`: Wandb project
|
102 |
+
- `wandb_entity`: Wandb entity, leave as `null` to use your default one
|
103 |
+
- `wandb_mode` : Wandb mode
|
104 |
+
- `video_frequency`: How often to log videos (they are quite large)
|
105 |
+
- `load_from_checkpoint`: WWandb artifact path to load from
|
106 |
+
- `load_only_params`: Whether to load just the network parameters or entire train state.
|
107 |
+
- `checkpoint_save_freq`: How often to log checkpoits
|
108 |
+
- `checkpoint_human_numbers`: Should the checkpoints have human-readable timestep numbers
|
109 |
+
- `load_legacy_checkpoint`: Do not use
|
110 |
+
- `load_train_levels_legacy`: Do not use
|
111 |
+
- `economical_saving`: If true, only saves a few important checkpoints for space conservation purposes.
|
112 |
+
### Eval
|
113 |
+
This option (see `configs/eval`) controls how evaluation works, and what levels are used.
|
114 |
+
#### Preset Options
|
115 |
+
- `s`: Eval on the `s` hand-designed levels located in `worlds/s`
|
116 |
+
- `m`: Eval on the `m` hand-designed levels located in `worlds/m`
|
117 |
+
- `l`: Eval on the `l` hand-designed levels located in `worlds/l`
|
118 |
+
- `eval_all`: Eval on all of the hand-designed eval levels
|
119 |
+
- `eval_auto`: If `train_levels` is not random, evaluate on the training levels.
|
120 |
+
- `mujoco`: Eval on the recreations of the mujoco tasks.
|
121 |
+
- `eval_general`: General option if you are planning on overwriting most options.
|
122 |
+
#### Individual Subkeys
|
123 |
+
- `eval_levels`: List of eval levels or the string "auto"
|
124 |
+
- `eval_num_attempts`: How many times to eval on the same level
|
125 |
+
- `eval_freq`: How often to evaluate
|
126 |
+
- `EVAL_ON_SAMPLED`: If true, in `plr.py` and `sfl.py`, evaluates on a fixed set of randomly-generated levels
|
127 |
+
|
128 |
+
### Eval Env Size
|
129 |
+
This controls the size of the evaluation environment. This is crucial to match up with the size of the evaluation levels.
|
130 |
+
#### Preset Options
|
131 |
+
- `s`: Same as the `env_size` option.
|
132 |
+
- `m`: Same as the `env_size` option.
|
133 |
+
- `l`: Same as the `env_size` option.
|
134 |
+
### Train Levels
|
135 |
+
Which levels to train on.
|
136 |
+
#### Preset Options
|
137 |
+
- `s`: All of the `s` holdout levels
|
138 |
+
- `m`: All of the `m` holdout levels
|
139 |
+
- `l`: All of the `l` holdout levels
|
140 |
+
- `train_all`: All of the levels from all 3 holdout sets
|
141 |
+
- `mujoco`: All of the mujoco recreation levels.
|
142 |
+
- `random`: Train on random levels
|
143 |
+
#### Individual Subkeys
|
144 |
+
- `train_level_mode`: "random" or "list"
|
145 |
+
- `train_level_distribution`: if train_level_mode=random, this controls which distribution to use. By default `distribution_v3`
|
146 |
+
- `train_levels_list`: This is a list of levels to train on.
|
147 |
+
### Model
|
148 |
+
This controls the model architecture and options associated with that.
|
149 |
+
#### Preset Options
|
150 |
+
We use both of the following:
|
151 |
+
- `model-base`
|
152 |
+
- `model-entity`
|
153 |
+
#### Individual Subkeys
|
154 |
+
`fc_layer_depth`: How many layers in the FC model
|
155 |
+
`fc_layer_width`: How wide is each FC layer
|
156 |
+
`activation`: NN activation
|
157 |
+
`recurrent_model`: Whether or not to use recurrence
|
158 |
+
The following are just relevant when using `env=entity`
|
159 |
+
`transformer_depth`: How many transformer layers to use
|
160 |
+
`transformer_size`: How large are the KQV vectors
|
161 |
+
`transformer_encoder_size`: How large are the initial embeddings
|
162 |
+
`num_heads`: How many heads, must be a multiple of 4 and divide `transformer_size` evenly.
|
163 |
+
`full_attention_mask`: If true, all heads use the full attention mask
|
164 |
+
`aggregate_mode`: `dummy_and_mean` works well.
|
165 |
+
### UED
|
166 |
+
Options pertaining to UED (i.e., when using the scripts `plr.py` or `sfl.py`)
|
167 |
+
#### Preset Options
|
168 |
+
- `sfl`
|
169 |
+
- `plr`
|
170 |
+
- `accel`
|
171 |
+
#### Individual Subkeys
|
172 |
+
See the individual files for the configuration options used.
|
173 |
+
For SFL, we have:
|
174 |
+
|
175 |
+
- `sampled_envs_ratio`: How many environments are from the SFL buffer and how many are randomly generated
|
176 |
+
- `batch_size`: How many levels to evaluate learnability on per batch
|
177 |
+
- `num_batches`: How many batches to run when choosing the most learnable levels
|
178 |
+
- `rollout_steps`: How many steps to rollout for when doing the learnability calculation.
|
179 |
+
- `num_to_save`: How many levels to save in the learnability buffer
|
Kinetix/examples/example_premade_level_replay.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import jax
|
2 |
+
import jax.numpy as jnp
|
3 |
+
import jax.random
|
4 |
+
from jax2d.engine import PhysicsEngine
|
5 |
+
from matplotlib import pyplot as plt
|
6 |
+
|
7 |
+
from kinetix.environment.env import make_kinetix_env_from_args
|
8 |
+
from kinetix.environment.env_state import StaticEnvParams, EnvParams
|
9 |
+
from kinetix.environment.ued.distributions import sample_kinetix_level
|
10 |
+
from kinetix.environment.ued.ued_state import UEDParams
|
11 |
+
from kinetix.render.renderer_pixels import make_render_pixels
|
12 |
+
from kinetix.util.saving import load_from_json_file
|
13 |
+
|
14 |
+
|
15 |
+
def main():
|
16 |
+
# Load a premade level
|
17 |
+
level, static_env_params, env_params = load_from_json_file("worlds/l/grasp_easy.json")
|
18 |
+
|
19 |
+
# Create the environment
|
20 |
+
env = make_kinetix_env_from_args(
|
21 |
+
obs_type="pixels", action_type="continuous", reset_type="replay", static_env_params=static_env_params
|
22 |
+
)
|
23 |
+
|
24 |
+
# Reset the environment state to this level
|
25 |
+
rng = jax.random.PRNGKey(0)
|
26 |
+
rng, _rng = jax.random.split(rng)
|
27 |
+
obs, env_state = env.reset_to_level(_rng, level, env_params)
|
28 |
+
|
29 |
+
# Take a step in the environment
|
30 |
+
rng, _rng = jax.random.split(rng)
|
31 |
+
action = env.action_space(env_params).sample(_rng)
|
32 |
+
rng, _rng = jax.random.split(rng)
|
33 |
+
obs, env_state, reward, done, info = env.step(_rng, env_state, action, env_params)
|
34 |
+
|
35 |
+
# Render environment
|
36 |
+
renderer = make_render_pixels(env_params, static_env_params)
|
37 |
+
|
38 |
+
# There are a lot of wrappers
|
39 |
+
pixels = renderer(env_state.env_state.env_state.env_state)
|
40 |
+
|
41 |
+
plt.imshow(pixels.astype(jnp.uint8).transpose(1, 0, 2)[::-1])
|
42 |
+
plt.show()
|
43 |
+
|
44 |
+
|
45 |
+
if __name__ == "__main__":
|
46 |
+
main()
|
Kinetix/examples/example_random_level_replay.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import jax
|
2 |
+
import jax.numpy as jnp
|
3 |
+
import jax.random
|
4 |
+
from jax2d.engine import PhysicsEngine
|
5 |
+
from matplotlib import pyplot as plt
|
6 |
+
|
7 |
+
from kinetix.environment.env import make_kinetix_env_from_args
|
8 |
+
from kinetix.environment.env_state import StaticEnvParams, EnvParams
|
9 |
+
from kinetix.environment.ued.distributions import sample_kinetix_level
|
10 |
+
from kinetix.environment.ued.ued_state import UEDParams
|
11 |
+
from kinetix.render.renderer_pixels import make_render_pixels
|
12 |
+
|
13 |
+
|
14 |
+
def main():
|
15 |
+
# Use default parameters
|
16 |
+
env_params = EnvParams()
|
17 |
+
static_env_params = StaticEnvParams()
|
18 |
+
ued_params = UEDParams()
|
19 |
+
|
20 |
+
# Create the environment
|
21 |
+
env = make_kinetix_env_from_args(
|
22 |
+
obs_type="pixels", action_type="continuous", reset_type="replay", static_env_params=static_env_params
|
23 |
+
)
|
24 |
+
|
25 |
+
# Sample a random level
|
26 |
+
rng = jax.random.PRNGKey(0)
|
27 |
+
rng, _rng = jax.random.split(rng)
|
28 |
+
level = sample_kinetix_level(_rng, env.physics_engine, env_params, static_env_params, ued_params)
|
29 |
+
|
30 |
+
# Reset the environment state to this level
|
31 |
+
rng, _rng = jax.random.split(rng)
|
32 |
+
obs, env_state = env.reset_to_level(_rng, level, env_params)
|
33 |
+
|
34 |
+
# Take a step in the environment
|
35 |
+
rng, _rng = jax.random.split(rng)
|
36 |
+
action = env.action_space(env_params).sample(_rng)
|
37 |
+
rng, _rng = jax.random.split(rng)
|
38 |
+
obs, env_state, reward, done, info = env.step(_rng, env_state, action, env_params)
|
39 |
+
|
40 |
+
# Render environment
|
41 |
+
renderer = make_render_pixels(env_params, static_env_params)
|
42 |
+
|
43 |
+
# There are a lot of wrappers
|
44 |
+
pixels = renderer(env_state.env_state.env_state.env_state)
|
45 |
+
|
46 |
+
plt.imshow(pixels.astype(jnp.uint8).transpose(1, 0, 2)[::-1])
|
47 |
+
plt.show()
|
48 |
+
|
49 |
+
|
50 |
+
if __name__ == "__main__":
|
51 |
+
main()
|
Kinetix/experiments/plr.py
ADDED
@@ -0,0 +1,1143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
import time
|
3 |
+
from enum import IntEnum
|
4 |
+
from typing import Tuple
|
5 |
+
|
6 |
+
import chex
|
7 |
+
import hydra
|
8 |
+
import jax
|
9 |
+
import jax.numpy as jnp
|
10 |
+
import numpy as np
|
11 |
+
from omegaconf import OmegaConf
|
12 |
+
import optax
|
13 |
+
from flax import core, struct
|
14 |
+
from flax.training.train_state import TrainState as BaseTrainState
|
15 |
+
|
16 |
+
import wandb
|
17 |
+
from kinetix.environment.ued.distributions import (
|
18 |
+
create_random_starting_distribution,
|
19 |
+
)
|
20 |
+
from kinetix.environment.ued.ued import (
|
21 |
+
make_mutate_env,
|
22 |
+
make_reset_train_function_with_mutations,
|
23 |
+
make_vmapped_filtered_level_sampler,
|
24 |
+
)
|
25 |
+
from kinetix.environment.ued.ued import (
|
26 |
+
make_mutate_env,
|
27 |
+
make_reset_train_function_with_list_of_levels,
|
28 |
+
make_reset_train_function_with_mutations,
|
29 |
+
)
|
30 |
+
from kinetix.util.config import (
|
31 |
+
generate_ued_params_from_config,
|
32 |
+
get_video_frequency,
|
33 |
+
init_wandb,
|
34 |
+
normalise_config,
|
35 |
+
save_data_to_local_file,
|
36 |
+
generate_params_from_config,
|
37 |
+
get_eval_level_groups,
|
38 |
+
)
|
39 |
+
from jaxued.environments.underspecified_env import EnvState
|
40 |
+
from jaxued.level_sampler import LevelSampler
|
41 |
+
from jaxued.utils import compute_max_returns, max_mc, positive_value_loss
|
42 |
+
from flax.serialization import to_state_dict
|
43 |
+
|
44 |
+
import sys
|
45 |
+
|
46 |
+
sys.path.append("experiments")
|
47 |
+
from kinetix.environment.env import make_kinetix_env_from_name
|
48 |
+
from kinetix.environment.env_state import StaticEnvParams
|
49 |
+
from kinetix.environment.wrappers import (
|
50 |
+
UnderspecifiedToGymnaxWrapper,
|
51 |
+
LogWrapper,
|
52 |
+
DenseRewardWrapper,
|
53 |
+
AutoReplayWrapper,
|
54 |
+
)
|
55 |
+
from kinetix.models import make_network_from_config
|
56 |
+
from kinetix.render.renderer_pixels import make_render_pixels
|
57 |
+
from kinetix.models.actor_critic import ScannedRNN
|
58 |
+
from kinetix.util.learning import (
|
59 |
+
general_eval,
|
60 |
+
get_eval_levels,
|
61 |
+
no_op_and_random_rollout,
|
62 |
+
sample_trajectories_and_learn,
|
63 |
+
)
|
64 |
+
from kinetix.util.saving import (
|
65 |
+
load_train_state_from_wandb_artifact_path,
|
66 |
+
save_model_to_wandb,
|
67 |
+
)
|
68 |
+
|
69 |
+
|
70 |
+
class UpdateState(IntEnum):
|
71 |
+
DR = 0
|
72 |
+
REPLAY = 1
|
73 |
+
MUTATE = 2
|
74 |
+
|
75 |
+
|
76 |
+
def get_level_complexity_metrics(all_levels: EnvState, static_env_params: StaticEnvParams):
|
77 |
+
def get_for_single_level(level):
|
78 |
+
return {
|
79 |
+
"complexity/num_shapes": level.polygon.active[static_env_params.num_static_fixated_polys :].sum()
|
80 |
+
+ level.circle.active.sum(),
|
81 |
+
"complexity/num_joints": level.joint.active.sum(),
|
82 |
+
"complexity/num_thrusters": level.thruster.active.sum(),
|
83 |
+
"complexity/num_rjoints": (level.joint.active * jnp.logical_not(level.joint.is_fixed_joint)).sum(),
|
84 |
+
"complexity/num_fjoints": (level.joint.active * (level.joint.is_fixed_joint)).sum(),
|
85 |
+
"complexity/has_ball": ((level.polygon_shape_roles == 1) * level.polygon.active).sum()
|
86 |
+
+ ((level.circle_shape_roles == 1) * level.circle.active).sum(),
|
87 |
+
"complexity/has_goal": ((level.polygon_shape_roles == 2) * level.polygon.active).sum()
|
88 |
+
+ ((level.circle_shape_roles == 2) * level.circle.active).sum(),
|
89 |
+
}
|
90 |
+
|
91 |
+
return jax.tree.map(lambda x: x.mean(), jax.vmap(get_for_single_level)(all_levels))
|
92 |
+
|
93 |
+
|
94 |
+
def get_ued_score_metrics(all_ued_scores):
|
95 |
+
(mc, pvl, learn) = all_ued_scores
|
96 |
+
scores = {}
|
97 |
+
for score, name in zip([mc, pvl, learn], ["MaxMC", "PVL", "Learnability"]):
|
98 |
+
scores[f"ued_scores/{name}/Mean"] = score.mean()
|
99 |
+
scores[f"ued_scores_additional/{name}/Max"] = score.max()
|
100 |
+
scores[f"ued_scores_additional/{name}/Min"] = score.min()
|
101 |
+
|
102 |
+
return scores
|
103 |
+
|
104 |
+
|
105 |
+
class TrainState(BaseTrainState):
|
106 |
+
sampler: core.FrozenDict[str, chex.ArrayTree] = struct.field(pytree_node=True)
|
107 |
+
update_state: UpdateState = struct.field(pytree_node=True)
|
108 |
+
# === Below is used for logging ===
|
109 |
+
num_dr_updates: int
|
110 |
+
num_replay_updates: int
|
111 |
+
num_mutation_updates: int
|
112 |
+
|
113 |
+
dr_last_level_batch_scores: chex.ArrayTree = struct.field(pytree_node=True)
|
114 |
+
replay_last_level_batch_scores: chex.ArrayTree = struct.field(pytree_node=True)
|
115 |
+
mutation_last_level_batch_scores: chex.ArrayTree = struct.field(pytree_node=True)
|
116 |
+
|
117 |
+
dr_last_level_batch: chex.ArrayTree = struct.field(pytree_node=True)
|
118 |
+
replay_last_level_batch: chex.ArrayTree = struct.field(pytree_node=True)
|
119 |
+
mutation_last_level_batch: chex.ArrayTree = struct.field(pytree_node=True)
|
120 |
+
|
121 |
+
dr_last_rollout_batch: chex.ArrayTree = struct.field(pytree_node=True)
|
122 |
+
replay_last_rollout_batch: chex.ArrayTree = struct.field(pytree_node=True)
|
123 |
+
mutation_last_rollout_batch: chex.ArrayTree = struct.field(pytree_node=True)
|
124 |
+
|
125 |
+
|
126 |
+
# region PPO helper functions
|
127 |
+
|
128 |
+
# endregion
|
129 |
+
|
130 |
+
|
131 |
+
def train_state_to_log_dict(train_state: TrainState, level_sampler: LevelSampler) -> dict:
|
132 |
+
"""To prevent the entire (large) train_state to be copied to the CPU when doing logging, this function returns all of the important information in a dictionary format.
|
133 |
+
|
134 |
+
Anything in the `log` key will be logged to wandb.
|
135 |
+
|
136 |
+
Args:
|
137 |
+
train_state (TrainState):
|
138 |
+
level_sampler (LevelSampler):
|
139 |
+
|
140 |
+
Returns:
|
141 |
+
dict:
|
142 |
+
"""
|
143 |
+
sampler = train_state.sampler
|
144 |
+
idx = jnp.arange(level_sampler.capacity) < sampler["size"]
|
145 |
+
s = jnp.maximum(idx.sum(), 1)
|
146 |
+
return {
|
147 |
+
"log": {
|
148 |
+
"level_sampler/size": sampler["size"],
|
149 |
+
"level_sampler/episode_count": sampler["episode_count"],
|
150 |
+
"level_sampler/max_score": sampler["scores"].max(),
|
151 |
+
"level_sampler/weighted_score": (sampler["scores"] * level_sampler.level_weights(sampler)).sum(),
|
152 |
+
"level_sampler/mean_score": (sampler["scores"] * idx).sum() / s,
|
153 |
+
},
|
154 |
+
"info": {
|
155 |
+
"num_dr_updates": train_state.num_dr_updates,
|
156 |
+
"num_replay_updates": train_state.num_replay_updates,
|
157 |
+
"num_mutation_updates": train_state.num_mutation_updates,
|
158 |
+
},
|
159 |
+
}
|
160 |
+
|
161 |
+
|
162 |
+
def compute_learnability(config, done, reward, info, num_envs):
|
163 |
+
num_agents = 1
|
164 |
+
BATCH_ACTORS = num_envs * num_agents
|
165 |
+
|
166 |
+
rollout_length = config["num_steps"] * config["outer_rollout_steps"]
|
167 |
+
|
168 |
+
@partial(jax.vmap, in_axes=(None, 1, 1, 1))
|
169 |
+
@partial(jax.jit, static_argnums=(0,))
|
170 |
+
def _calc_outcomes_by_agent(max_steps: int, dones, returns, info):
|
171 |
+
idxs = jnp.arange(max_steps)
|
172 |
+
|
173 |
+
@partial(jax.vmap, in_axes=(0, 0))
|
174 |
+
def __ep_outcomes(start_idx, end_idx):
|
175 |
+
mask = (idxs > start_idx) & (idxs <= end_idx) & (end_idx != max_steps)
|
176 |
+
r = jnp.sum(returns * mask)
|
177 |
+
goal_r = info["GoalR"]
|
178 |
+
success = jnp.sum(goal_r * mask)
|
179 |
+
collision = 0
|
180 |
+
timeo = 0
|
181 |
+
l = end_idx - start_idx
|
182 |
+
return r, success, collision, timeo, l
|
183 |
+
|
184 |
+
done_idxs = jnp.argwhere(dones, size=50, fill_value=max_steps).squeeze()
|
185 |
+
mask_done = jnp.where(done_idxs == max_steps, 0, 1)
|
186 |
+
ep_return, success, collision, timeo, length = __ep_outcomes(
|
187 |
+
jnp.concatenate([jnp.array([-1]), done_idxs[:-1]]), done_idxs
|
188 |
+
)
|
189 |
+
|
190 |
+
return {
|
191 |
+
"ep_return": ep_return.mean(where=mask_done),
|
192 |
+
"num_episodes": mask_done.sum(),
|
193 |
+
"num_success": success.sum(where=mask_done),
|
194 |
+
"success_rate": success.mean(where=mask_done),
|
195 |
+
"collision_rate": collision.mean(where=mask_done),
|
196 |
+
"timeout_rate": timeo.mean(where=mask_done),
|
197 |
+
"ep_len": length.mean(where=mask_done),
|
198 |
+
}
|
199 |
+
|
200 |
+
done_by_env = done.reshape((-1, num_agents, num_envs))
|
201 |
+
reward_by_env = reward.reshape((-1, num_agents, num_envs))
|
202 |
+
o = _calc_outcomes_by_agent(rollout_length, done, reward, info)
|
203 |
+
success_by_env = o["success_rate"].reshape((num_agents, num_envs))
|
204 |
+
learnability_by_env = (success_by_env * (1 - success_by_env)).sum(axis=0)
|
205 |
+
|
206 |
+
return (
|
207 |
+
learnability_by_env,
|
208 |
+
o["num_episodes"].reshape(num_agents, num_envs).sum(axis=0),
|
209 |
+
o["num_success"].reshape(num_agents, num_envs).T,
|
210 |
+
) # so agents is at the end.
|
211 |
+
|
212 |
+
|
213 |
+
def compute_score(
|
214 |
+
config: dict, dones: chex.Array, values: chex.Array, max_returns: chex.Array, reward, info, advantages: chex.Array
|
215 |
+
) -> chex.Array:
|
216 |
+
# Computes the score for each level
|
217 |
+
if config["score_function"] == "MaxMC":
|
218 |
+
return max_mc(dones, values, max_returns)
|
219 |
+
elif config["score_function"] == "pvl":
|
220 |
+
return positive_value_loss(dones, advantages)
|
221 |
+
elif config["score_function"] == "learnability":
|
222 |
+
learnability, num_episodes, num_success = compute_learnability(
|
223 |
+
config, dones, reward, info, config["num_train_envs"]
|
224 |
+
)
|
225 |
+
return learnability
|
226 |
+
else:
|
227 |
+
raise ValueError(f"Unknown score function: {config['score_function']}")
|
228 |
+
|
229 |
+
|
230 |
+
def compute_all_scores(
|
231 |
+
config: dict,
|
232 |
+
dones: chex.Array,
|
233 |
+
values: chex.Array,
|
234 |
+
max_returns: chex.Array,
|
235 |
+
reward,
|
236 |
+
info,
|
237 |
+
advantages: chex.Array,
|
238 |
+
return_success_rate=False,
|
239 |
+
):
|
240 |
+
mc = max_mc(dones, values, max_returns)
|
241 |
+
pvl = positive_value_loss(dones, advantages)
|
242 |
+
learnability, num_episodes, num_success = compute_learnability(
|
243 |
+
config, dones, reward, info, config["num_train_envs"]
|
244 |
+
)
|
245 |
+
if config["score_function"] == "MaxMC":
|
246 |
+
main_score = mc
|
247 |
+
elif config["score_function"] == "pvl":
|
248 |
+
main_score = pvl
|
249 |
+
elif config["score_function"] == "learnability":
|
250 |
+
main_score = learnability
|
251 |
+
else:
|
252 |
+
raise ValueError(f"Unknown score function: {config['score_function']}")
|
253 |
+
if return_success_rate:
|
254 |
+
success_rate = num_success.squeeze(1) / jnp.maximum(num_episodes, 1)
|
255 |
+
return main_score, (mc, pvl, learnability, success_rate)
|
256 |
+
return main_score, (mc, pvl, learnability)
|
257 |
+
|
258 |
+
|
259 |
+
@hydra.main(version_base=None, config_path="../configs", config_name="plr")
|
260 |
+
def main(config=None):
|
261 |
+
my_name = "PLR"
|
262 |
+
config = OmegaConf.to_container(config)
|
263 |
+
if config["ued"]["replay_prob"] == 0.0:
|
264 |
+
my_name = "DR"
|
265 |
+
elif config["ued"]["use_accel"]:
|
266 |
+
my_name = "ACCEL"
|
267 |
+
|
268 |
+
time_start = time.time()
|
269 |
+
config = normalise_config(config, my_name)
|
270 |
+
env_params, static_env_params = generate_params_from_config(config)
|
271 |
+
config["env_params"] = to_state_dict(env_params)
|
272 |
+
config["static_env_params"] = to_state_dict(static_env_params)
|
273 |
+
|
274 |
+
run = init_wandb(config, my_name)
|
275 |
+
config = wandb.config
|
276 |
+
time_prev = time.time()
|
277 |
+
|
278 |
+
def log_eval(stats, train_state_info):
|
279 |
+
nonlocal time_prev
|
280 |
+
print(f"Logging update: {stats['update_count']}")
|
281 |
+
total_loss = jnp.mean(stats["losses"][0])
|
282 |
+
if jnp.isnan(total_loss):
|
283 |
+
print("NaN loss, skipping logging")
|
284 |
+
raise ValueError("NaN loss")
|
285 |
+
|
286 |
+
# generic stats
|
287 |
+
env_steps = int(
|
288 |
+
int(stats["update_count"]) * config["num_train_envs"] * config["num_steps"] * config["outer_rollout_steps"]
|
289 |
+
)
|
290 |
+
env_steps_delta = (
|
291 |
+
config["eval_freq"] * config["num_train_envs"] * config["num_steps"] * config["outer_rollout_steps"]
|
292 |
+
)
|
293 |
+
time_now = time.time()
|
294 |
+
log_dict = {
|
295 |
+
"timing/num_updates": stats["update_count"],
|
296 |
+
"timing/num_env_steps": env_steps,
|
297 |
+
"timing/sps": env_steps_delta / (time_now - time_prev),
|
298 |
+
"timing/sps_agg": env_steps / (time_now - time_start),
|
299 |
+
"loss/total_loss": jnp.mean(stats["losses"][0]),
|
300 |
+
"loss/value_loss": jnp.mean(stats["losses"][1][0]),
|
301 |
+
"loss/policy_loss": jnp.mean(stats["losses"][1][1]),
|
302 |
+
"loss/entropy_loss": jnp.mean(stats["losses"][1][2]),
|
303 |
+
}
|
304 |
+
time_prev = time_now
|
305 |
+
|
306 |
+
# evaluation performance
|
307 |
+
|
308 |
+
returns = stats["eval_returns"]
|
309 |
+
log_dict.update({"eval/mean_eval_return": returns.mean()})
|
310 |
+
log_dict.update({"eval/mean_eval_learnability": stats["eval_learn"].mean()})
|
311 |
+
log_dict.update({"eval/mean_eval_solve_rate": stats["eval_solves"].mean()})
|
312 |
+
log_dict.update({"eval/mean_eval_eplen": stats["eval_ep_lengths"].mean()})
|
313 |
+
for i in range(config["num_eval_levels"]):
|
314 |
+
log_dict[f"eval_avg_return/{config['eval_levels'][i]}"] = returns[i]
|
315 |
+
log_dict[f"eval_avg_learnability/{config['eval_levels'][i]}"] = stats["eval_learn"][i]
|
316 |
+
log_dict[f"eval_avg_solve_rate/{config['eval_levels'][i]}"] = stats["eval_solves"][i]
|
317 |
+
log_dict[f"eval_avg_episode_length/{config['eval_levels'][i]}"] = stats["eval_ep_lengths"][i]
|
318 |
+
log_dict[f"eval_get_max_eplen/{config['eval_levels'][i]}"] = stats["eval_get_max_eplen"][i]
|
319 |
+
log_dict[f"episode_return_bigger_than_negative/{config['eval_levels'][i]}"] = stats[
|
320 |
+
"episode_return_bigger_than_negative"
|
321 |
+
][i]
|
322 |
+
|
323 |
+
def _aggregate_per_size(values, name):
|
324 |
+
to_return = {}
|
325 |
+
for group_name, indices in eval_group_indices.items():
|
326 |
+
to_return[f"{name}_{group_name}"] = values[indices].mean()
|
327 |
+
return to_return
|
328 |
+
|
329 |
+
log_dict.update(_aggregate_per_size(returns, "eval_aggregate/return"))
|
330 |
+
log_dict.update(_aggregate_per_size(stats["eval_solves"], "eval_aggregate/solve_rate"))
|
331 |
+
|
332 |
+
if config["EVAL_ON_SAMPLED"]:
|
333 |
+
log_dict.update({"eval/mean_eval_return_sampled": stats["eval_dr_returns"].mean()})
|
334 |
+
log_dict.update({"eval/mean_eval_solve_rate_sampled": stats["eval_dr_solve_rates"].mean()})
|
335 |
+
log_dict.update({"eval/mean_eval_eplen_sampled": stats["eval_dr_eplen"].mean()})
|
336 |
+
|
337 |
+
# level sampler
|
338 |
+
log_dict.update(train_state_info["log"])
|
339 |
+
|
340 |
+
# images
|
341 |
+
log_dict.update(
|
342 |
+
{
|
343 |
+
"images/highest_scoring_level": wandb.Image(
|
344 |
+
np.array(stats["highest_scoring_level"]), caption="Highest scoring level"
|
345 |
+
)
|
346 |
+
}
|
347 |
+
)
|
348 |
+
log_dict.update(
|
349 |
+
{
|
350 |
+
"images/highest_weighted_level": wandb.Image(
|
351 |
+
np.array(stats["highest_weighted_level"]), caption="Highest weighted level"
|
352 |
+
)
|
353 |
+
}
|
354 |
+
)
|
355 |
+
|
356 |
+
for s in ["dr", "replay", "mutation"]:
|
357 |
+
if train_state_info["info"][f"num_{s}_updates"] > 0:
|
358 |
+
log_dict.update(
|
359 |
+
{
|
360 |
+
f"images/{s}_levels": [
|
361 |
+
wandb.Image(np.array(image), caption=f"{score}")
|
362 |
+
for image, score in zip(stats[f"{s}_levels"], stats[f"{s}_scores"])
|
363 |
+
]
|
364 |
+
}
|
365 |
+
)
|
366 |
+
if stats["log_videos"]:
|
367 |
+
# animations
|
368 |
+
rollout_ep = stats[f"{s}_ep_len"]
|
369 |
+
arr = np.array(stats[f"{s}_rollout"][:rollout_ep])
|
370 |
+
log_dict.update(
|
371 |
+
{
|
372 |
+
f"media/{s}_eval": wandb.Video(
|
373 |
+
arr.astype(np.uint8), fps=15, caption=f"{s.capitalize()} (len {rollout_ep})"
|
374 |
+
)
|
375 |
+
}
|
376 |
+
)
|
377 |
+
# * 255
|
378 |
+
|
379 |
+
# DR, Replay and Mutate Returns
|
380 |
+
dr_inds = (stats["update_state"] == UpdateState.DR).nonzero()[0]
|
381 |
+
rep_inds = (stats["update_state"] == UpdateState.REPLAY).nonzero()[0]
|
382 |
+
mut_inds = (stats["update_state"] == UpdateState.MUTATE).nonzero()[0]
|
383 |
+
|
384 |
+
for name, inds in [
|
385 |
+
("DR", dr_inds),
|
386 |
+
("REPLAY", rep_inds),
|
387 |
+
("MUTATION", mut_inds),
|
388 |
+
]:
|
389 |
+
if len(inds) > 0:
|
390 |
+
log_dict.update(
|
391 |
+
{
|
392 |
+
f"{name}/episode_return": stats["episode_return"][inds].mean(),
|
393 |
+
f"{name}/mean_eplen": stats["returned_episode_lengths"][inds].mean(),
|
394 |
+
f"{name}/mean_success": stats["returned_episode_solved"][inds].mean(),
|
395 |
+
f"{name}/noop_return": stats["noop_returns"][inds].mean(),
|
396 |
+
f"{name}/noop_eplen": stats["noop_eplen"][inds].mean(),
|
397 |
+
f"{name}/noop_success": stats["noop_success"][inds].mean(),
|
398 |
+
f"{name}/random_return": stats["random_returns"][inds].mean(),
|
399 |
+
f"{name}/random_eplen": stats["random_eplen"][inds].mean(),
|
400 |
+
f"{name}/random_success": stats["random_success"][inds].mean(),
|
401 |
+
}
|
402 |
+
)
|
403 |
+
for k in stats:
|
404 |
+
if "complexity/" in k:
|
405 |
+
k2 = "complexity/" + name + "_" + k.replace("complexity/", "")
|
406 |
+
log_dict.update({k2: stats[k][inds].mean()})
|
407 |
+
if "ued_scores/" in k:
|
408 |
+
k2 = "ued_scores/" + name + "_" + k.replace("ued_scores/", "")
|
409 |
+
log_dict.update({k2: stats[k][inds].mean()})
|
410 |
+
|
411 |
+
# Eval rollout animations
|
412 |
+
if stats["log_videos"]:
|
413 |
+
for i in range((config["num_eval_levels"])):
|
414 |
+
frames, episode_length = stats["eval_animation"][0][:, i], stats["eval_animation"][1][i]
|
415 |
+
frames = np.array(frames[:episode_length])
|
416 |
+
log_dict.update(
|
417 |
+
{
|
418 |
+
f"media/eval_video_{config['eval_levels'][i]}": wandb.Video(
|
419 |
+
frames.astype(np.uint8), fps=15, caption=f"Len ({episode_length})"
|
420 |
+
)
|
421 |
+
}
|
422 |
+
)
|
423 |
+
|
424 |
+
wandb.log(log_dict)
|
425 |
+
|
426 |
+
def get_all_metrics(
|
427 |
+
rng,
|
428 |
+
losses,
|
429 |
+
info,
|
430 |
+
init_env_state,
|
431 |
+
init_obs,
|
432 |
+
dones,
|
433 |
+
grads,
|
434 |
+
all_ued_scores,
|
435 |
+
new_levels,
|
436 |
+
):
|
437 |
+
noop_returns, noop_len, noop_success, random_returns, random_lens, random_success = no_op_and_random_rollout(
|
438 |
+
env,
|
439 |
+
env_params,
|
440 |
+
rng,
|
441 |
+
init_obs,
|
442 |
+
init_env_state,
|
443 |
+
config["num_train_envs"],
|
444 |
+
config["num_steps"] * config["outer_rollout_steps"],
|
445 |
+
)
|
446 |
+
metrics = (
|
447 |
+
{
|
448 |
+
"losses": jax.tree_util.tree_map(lambda x: x.mean(), losses),
|
449 |
+
"returned_episode_lengths": (info["returned_episode_lengths"] * dones).sum()
|
450 |
+
/ jnp.maximum(1, dones.sum()),
|
451 |
+
"max_episode_length": info["returned_episode_lengths"].max(),
|
452 |
+
"levels_played": init_env_state.env_state.env_state,
|
453 |
+
"episode_return": (info["returned_episode_returns"] * dones).sum() / jnp.maximum(1, dones.sum()),
|
454 |
+
"episode_return_v2": (info["returned_episode_returns"] * info["returned_episode"]).sum()
|
455 |
+
/ jnp.maximum(1, info["returned_episode"].sum()),
|
456 |
+
"grad_norms": grads.mean(),
|
457 |
+
"noop_returns": noop_returns,
|
458 |
+
"noop_eplen": noop_len,
|
459 |
+
"noop_success": noop_success,
|
460 |
+
"random_returns": random_returns,
|
461 |
+
"random_eplen": random_lens,
|
462 |
+
"random_success": random_success,
|
463 |
+
"returned_episode_solved": (info["returned_episode_solved"] * dones).sum()
|
464 |
+
/ jnp.maximum(1, dones.sum()),
|
465 |
+
}
|
466 |
+
| get_level_complexity_metrics(new_levels, static_env_params)
|
467 |
+
| get_ued_score_metrics(all_ued_scores)
|
468 |
+
)
|
469 |
+
return metrics
|
470 |
+
|
471 |
+
# Setup the environment.
|
472 |
+
def make_env(static_env_params):
|
473 |
+
env = make_kinetix_env_from_name(config["env_name"], static_env_params=static_env_params)
|
474 |
+
env = AutoReplayWrapper(env)
|
475 |
+
env = UnderspecifiedToGymnaxWrapper(env)
|
476 |
+
env = DenseRewardWrapper(env, dense_reward_scale=config["dense_reward_scale"])
|
477 |
+
env = LogWrapper(env)
|
478 |
+
return env
|
479 |
+
|
480 |
+
env = make_env(static_env_params)
|
481 |
+
|
482 |
+
if config["train_level_mode"] == "list":
|
483 |
+
sample_random_level = make_reset_train_function_with_list_of_levels(
|
484 |
+
config, config["train_levels_list"], static_env_params, make_pcg_state=False, is_loading_train_levels=True
|
485 |
+
)
|
486 |
+
elif config["train_level_mode"] == "random":
|
487 |
+
sample_random_level = make_reset_train_function_with_mutations(
|
488 |
+
env.physics_engine, env_params, static_env_params, config, make_pcg_state=False
|
489 |
+
)
|
490 |
+
else:
|
491 |
+
raise ValueError(f"Unknown train_level_mode: {config['train_level_mode']}")
|
492 |
+
|
493 |
+
if config["use_accel"] and config["accel_start_from_empty"]:
|
494 |
+
|
495 |
+
def make_sample_random_level():
|
496 |
+
def inner(rng):
|
497 |
+
def _inner_accel(rng):
|
498 |
+
return create_random_starting_distribution(
|
499 |
+
rng, env_params, static_env_params, ued_params, config["env_size_name"], controllable=True
|
500 |
+
)
|
501 |
+
|
502 |
+
def _inner_accel_not_controllable(rng):
|
503 |
+
return create_random_starting_distribution(
|
504 |
+
rng, env_params, static_env_params, ued_params, config["env_size_name"], controllable=False
|
505 |
+
)
|
506 |
+
|
507 |
+
rng, _rng = jax.random.split(rng)
|
508 |
+
return _inner_accel(_rng)
|
509 |
+
|
510 |
+
return inner
|
511 |
+
|
512 |
+
sample_random_level = make_sample_random_level()
|
513 |
+
|
514 |
+
sample_random_levels = make_vmapped_filtered_level_sampler(
|
515 |
+
sample_random_level, env_params, static_env_params, config, make_pcg_state=False, env=env
|
516 |
+
)
|
517 |
+
|
518 |
+
def generate_world():
|
519 |
+
raise NotImplementedError
|
520 |
+
pass
|
521 |
+
|
522 |
+
def generate_eval_world(rng, env_params, static_env_params, level_idx):
|
523 |
+
# jax.random.split(jax.random.PRNGKey(101), num_levels), env_params, static_env_params, jnp.arange(num_levels)
|
524 |
+
|
525 |
+
raise NotImplementedError
|
526 |
+
|
527 |
+
_, eval_static_env_params = generate_params_from_config(
|
528 |
+
config["eval_env_size_true"] | {"frame_skip": config["frame_skip"]}
|
529 |
+
)
|
530 |
+
eval_env = make_env(eval_static_env_params)
|
531 |
+
ued_params = generate_ued_params_from_config(config)
|
532 |
+
|
533 |
+
mutate_world = make_mutate_env(static_env_params, env_params, ued_params)
|
534 |
+
|
535 |
+
def make_render_fn(static_env_params):
|
536 |
+
render_fn_inner = make_render_pixels(env_params, static_env_params)
|
537 |
+
render_fn = lambda x: render_fn_inner(x).transpose(1, 0, 2)[::-1]
|
538 |
+
return render_fn
|
539 |
+
|
540 |
+
render_fn = make_render_fn(static_env_params)
|
541 |
+
render_fn_eval = make_render_fn(eval_static_env_params)
|
542 |
+
if config["EVAL_ON_SAMPLED"]:
|
543 |
+
NUM_EVAL_DR_LEVELS = 200
|
544 |
+
key_to_sample_dr_eval_set = jax.random.PRNGKey(100)
|
545 |
+
DR_EVAL_LEVELS = sample_random_levels(key_to_sample_dr_eval_set, NUM_EVAL_DR_LEVELS)
|
546 |
+
|
547 |
+
# And the level sampler
|
548 |
+
level_sampler = LevelSampler(
|
549 |
+
capacity=config["level_buffer_capacity"],
|
550 |
+
replay_prob=config["replay_prob"],
|
551 |
+
staleness_coeff=config["staleness_coeff"],
|
552 |
+
minimum_fill_ratio=config["minimum_fill_ratio"],
|
553 |
+
prioritization=config["prioritization"],
|
554 |
+
prioritization_params={"temperature": config["temperature"], "k": config["topk_k"]},
|
555 |
+
duplicate_check=config["buffer_duplicate_check"],
|
556 |
+
)
|
557 |
+
|
558 |
+
@jax.jit
|
559 |
+
def create_train_state(rng) -> TrainState:
|
560 |
+
# Creates the train state
|
561 |
+
def linear_schedule(count):
|
562 |
+
frac = 1.0 - (count // (config["num_minibatches"] * config["update_epochs"])) / (
|
563 |
+
config["num_updates"] * config["outer_rollout_steps"]
|
564 |
+
)
|
565 |
+
return config["lr"] * frac
|
566 |
+
|
567 |
+
rng, _rng = jax.random.split(rng)
|
568 |
+
init_state = jax.tree.map(lambda x: x[0], sample_random_levels(_rng, 1))
|
569 |
+
|
570 |
+
rng, _rng = jax.random.split(rng)
|
571 |
+
obs, _ = env.reset_to_level(_rng, init_state, env_params)
|
572 |
+
ns = config["num_steps"] * config["outer_rollout_steps"]
|
573 |
+
obs = jax.tree.map(
|
574 |
+
lambda x: jnp.repeat(jnp.repeat(x[None, ...], config["num_train_envs"], axis=0)[None, ...], ns, axis=0),
|
575 |
+
obs,
|
576 |
+
)
|
577 |
+
init_x = (obs, jnp.zeros((ns, config["num_train_envs"]), dtype=jnp.bool_))
|
578 |
+
network = make_network_from_config(env, env_params, config)
|
579 |
+
rng, _rng = jax.random.split(rng)
|
580 |
+
network_params = network.init(_rng, ScannedRNN.initialize_carry(config["num_train_envs"]), init_x)
|
581 |
+
|
582 |
+
if config["anneal_lr"]:
|
583 |
+
tx = optax.chain(
|
584 |
+
optax.clip_by_global_norm(config["max_grad_norm"]),
|
585 |
+
optax.adam(learning_rate=linear_schedule, eps=1e-5),
|
586 |
+
)
|
587 |
+
else:
|
588 |
+
tx = optax.chain(
|
589 |
+
optax.clip_by_global_norm(config["max_grad_norm"]),
|
590 |
+
optax.adam(config["lr"], eps=1e-5),
|
591 |
+
)
|
592 |
+
|
593 |
+
pholder_level = jax.tree.map(lambda x: x[0], sample_random_levels(jax.random.PRNGKey(0), 1))
|
594 |
+
sampler = level_sampler.initialize(pholder_level, {"max_return": -jnp.inf})
|
595 |
+
pholder_level_batch = jax.tree_util.tree_map(
|
596 |
+
lambda x: jnp.array([x]).repeat(config["num_train_envs"], axis=0), pholder_level
|
597 |
+
)
|
598 |
+
pholder_rollout_batch = (
|
599 |
+
jax.tree.map(
|
600 |
+
lambda x: jnp.repeat(
|
601 |
+
jnp.expand_dims(x, 0), repeats=config["num_steps"] * config["outer_rollout_steps"], axis=0
|
602 |
+
),
|
603 |
+
init_state,
|
604 |
+
),
|
605 |
+
init_x[1][:, 0],
|
606 |
+
)
|
607 |
+
|
608 |
+
pholder_level_batch_scores = jnp.zeros((config["num_train_envs"],), dtype=jnp.float32)
|
609 |
+
train_state = TrainState.create(
|
610 |
+
apply_fn=network.apply,
|
611 |
+
params=network_params,
|
612 |
+
tx=tx,
|
613 |
+
sampler=sampler,
|
614 |
+
update_state=0,
|
615 |
+
num_dr_updates=0,
|
616 |
+
num_replay_updates=0,
|
617 |
+
num_mutation_updates=0,
|
618 |
+
dr_last_level_batch_scores=pholder_level_batch_scores,
|
619 |
+
replay_last_level_batch_scores=pholder_level_batch_scores,
|
620 |
+
mutation_last_level_batch_scores=pholder_level_batch_scores,
|
621 |
+
dr_last_level_batch=pholder_level_batch,
|
622 |
+
replay_last_level_batch=pholder_level_batch,
|
623 |
+
mutation_last_level_batch=pholder_level_batch,
|
624 |
+
dr_last_rollout_batch=pholder_rollout_batch,
|
625 |
+
replay_last_rollout_batch=pholder_rollout_batch,
|
626 |
+
mutation_last_rollout_batch=pholder_rollout_batch,
|
627 |
+
)
|
628 |
+
|
629 |
+
if config["load_from_checkpoint"] != None:
|
630 |
+
print("LOADING from", config["load_from_checkpoint"], "with only params =", config["load_only_params"])
|
631 |
+
train_state = load_train_state_from_wandb_artifact_path(
|
632 |
+
train_state,
|
633 |
+
config["load_from_checkpoint"],
|
634 |
+
load_only_params=config["load_only_params"],
|
635 |
+
legacy=config["load_legacy_checkpoint"],
|
636 |
+
)
|
637 |
+
return train_state
|
638 |
+
|
639 |
+
all_eval_levels = get_eval_levels(config["eval_levels"], eval_env.static_env_params)
|
640 |
+
eval_group_indices = get_eval_level_groups(config["eval_levels"])
|
641 |
+
|
642 |
+
@jax.jit
|
643 |
+
def train_step(carry: Tuple[chex.PRNGKey, TrainState], _):
|
644 |
+
"""
|
645 |
+
This is the main training loop. It basically calls either `on_new_levels`, `on_replay_levels`, or `on_mutate_levels` at every step.
|
646 |
+
"""
|
647 |
+
|
648 |
+
def on_new_levels(rng: chex.PRNGKey, train_state: TrainState):
|
649 |
+
"""
|
650 |
+
Samples new (randomly-generated) levels and evaluates the policy on these. It also then adds the levels to the level buffer if they have high-enough scores.
|
651 |
+
The agent is updated on these trajectories iff `config["exploratory_grad_updates"]` is True.
|
652 |
+
"""
|
653 |
+
sampler = train_state.sampler
|
654 |
+
|
655 |
+
# Reset
|
656 |
+
rng, rng_levels, rng_reset = jax.random.split(rng, 3)
|
657 |
+
new_levels = sample_random_levels(rng_levels, config["num_train_envs"])
|
658 |
+
init_obs, init_env_state = jax.vmap(env.reset_to_level, in_axes=(0, 0, None))(
|
659 |
+
jax.random.split(rng_reset, config["num_train_envs"]), new_levels, env_params
|
660 |
+
)
|
661 |
+
init_hstate = ScannedRNN.initialize_carry(config["num_train_envs"])
|
662 |
+
# Rollout
|
663 |
+
(
|
664 |
+
(rng, train_state, new_hstate, last_obs, last_env_state),
|
665 |
+
(
|
666 |
+
obs,
|
667 |
+
actions,
|
668 |
+
rewards,
|
669 |
+
dones,
|
670 |
+
log_probs,
|
671 |
+
values,
|
672 |
+
info,
|
673 |
+
advantages,
|
674 |
+
targets,
|
675 |
+
losses,
|
676 |
+
grads,
|
677 |
+
rollout_states,
|
678 |
+
),
|
679 |
+
) = sample_trajectories_and_learn(
|
680 |
+
env,
|
681 |
+
env_params,
|
682 |
+
config,
|
683 |
+
rng,
|
684 |
+
train_state,
|
685 |
+
init_hstate,
|
686 |
+
init_obs,
|
687 |
+
init_env_state,
|
688 |
+
update_grad=config["exploratory_grad_updates"],
|
689 |
+
return_states=True,
|
690 |
+
)
|
691 |
+
max_returns = compute_max_returns(dones, rewards)
|
692 |
+
scores, all_ued_scores = compute_all_scores(config, dones, values, max_returns, rewards, info, advantages)
|
693 |
+
sampler, _ = level_sampler.insert_batch(sampler, new_levels, scores, {"max_return": max_returns})
|
694 |
+
rng, _rng = jax.random.split(rng)
|
695 |
+
metrics = {
|
696 |
+
"update_state": UpdateState.DR,
|
697 |
+
} | get_all_metrics(_rng, losses, info, init_env_state, init_obs, dones, grads, all_ued_scores, new_levels)
|
698 |
+
|
699 |
+
train_state = train_state.replace(
|
700 |
+
sampler=sampler,
|
701 |
+
update_state=UpdateState.DR,
|
702 |
+
num_dr_updates=train_state.num_dr_updates + 1,
|
703 |
+
dr_last_level_batch=new_levels,
|
704 |
+
dr_last_level_batch_scores=scores,
|
705 |
+
dr_last_rollout_batch=jax.tree.map(
|
706 |
+
lambda x: x[:, 0], (rollout_states.env_state.env_state.env_state, dones)
|
707 |
+
),
|
708 |
+
)
|
709 |
+
return (rng, train_state), metrics
|
710 |
+
|
711 |
+
def on_replay_levels(rng: chex.PRNGKey, train_state: TrainState):
|
712 |
+
"""
|
713 |
+
This samples levels from the level buffer, and updates the policy on them.
|
714 |
+
"""
|
715 |
+
sampler = train_state.sampler
|
716 |
+
|
717 |
+
# Collect trajectories on replay levels
|
718 |
+
rng, rng_levels, rng_reset = jax.random.split(rng, 3)
|
719 |
+
sampler, (level_inds, levels) = level_sampler.sample_replay_levels(
|
720 |
+
sampler, rng_levels, config["num_train_envs"]
|
721 |
+
)
|
722 |
+
init_obs, init_env_state = jax.vmap(env.reset_to_level, in_axes=(0, 0, None))(
|
723 |
+
jax.random.split(rng_reset, config["num_train_envs"]), levels, env_params
|
724 |
+
)
|
725 |
+
init_hstate = ScannedRNN.initialize_carry(config["num_train_envs"])
|
726 |
+
(
|
727 |
+
(rng, train_state, new_hstate, last_obs, last_env_state),
|
728 |
+
(
|
729 |
+
obs,
|
730 |
+
actions,
|
731 |
+
rewards,
|
732 |
+
dones,
|
733 |
+
log_probs,
|
734 |
+
values,
|
735 |
+
info,
|
736 |
+
advantages,
|
737 |
+
targets,
|
738 |
+
losses,
|
739 |
+
grads,
|
740 |
+
rollout_states,
|
741 |
+
),
|
742 |
+
) = sample_trajectories_and_learn(
|
743 |
+
env,
|
744 |
+
env_params,
|
745 |
+
config,
|
746 |
+
rng,
|
747 |
+
train_state,
|
748 |
+
init_hstate,
|
749 |
+
init_obs,
|
750 |
+
init_env_state,
|
751 |
+
update_grad=True,
|
752 |
+
return_states=True,
|
753 |
+
)
|
754 |
+
|
755 |
+
max_returns = jnp.maximum(
|
756 |
+
level_sampler.get_levels_extra(sampler, level_inds)["max_return"], compute_max_returns(dones, rewards)
|
757 |
+
)
|
758 |
+
scores, all_ued_scores = compute_all_scores(config, dones, values, max_returns, rewards, info, advantages)
|
759 |
+
sampler = level_sampler.update_batch(sampler, level_inds, scores, {"max_return": max_returns})
|
760 |
+
|
761 |
+
rng, _rng = jax.random.split(rng)
|
762 |
+
metrics = {
|
763 |
+
"update_state": UpdateState.REPLAY,
|
764 |
+
} | get_all_metrics(_rng, losses, info, init_env_state, init_obs, dones, grads, all_ued_scores, levels)
|
765 |
+
train_state = train_state.replace(
|
766 |
+
sampler=sampler,
|
767 |
+
update_state=UpdateState.REPLAY,
|
768 |
+
num_replay_updates=train_state.num_replay_updates + 1,
|
769 |
+
replay_last_level_batch=levels,
|
770 |
+
replay_last_level_batch_scores=scores,
|
771 |
+
replay_last_rollout_batch=jax.tree.map(
|
772 |
+
lambda x: x[:, 0], (rollout_states.env_state.env_state.env_state, dones)
|
773 |
+
),
|
774 |
+
)
|
775 |
+
return (rng, train_state), metrics
|
776 |
+
|
777 |
+
def on_mutate_levels(rng: chex.PRNGKey, train_state: TrainState):
|
778 |
+
"""
|
779 |
+
This mutates the previous batch of replay levels and potentially adds them to the level buffer.
|
780 |
+
This also updates the policy iff `config["exploratory_grad_updates"]` is True.
|
781 |
+
"""
|
782 |
+
|
783 |
+
sampler = train_state.sampler
|
784 |
+
rng, rng_mutate, rng_reset = jax.random.split(rng, 3)
|
785 |
+
|
786 |
+
# mutate
|
787 |
+
parent_levels = train_state.replay_last_level_batch
|
788 |
+
child_levels = jax.vmap(mutate_world, (0, 0, None))(
|
789 |
+
jax.random.split(rng_mutate, config["num_train_envs"]), parent_levels, config["num_edits"]
|
790 |
+
)
|
791 |
+
init_obs, init_env_state = jax.vmap(env.reset_to_level, in_axes=(0, 0, None))(
|
792 |
+
jax.random.split(rng_reset, config["num_train_envs"]), child_levels, env_params
|
793 |
+
)
|
794 |
+
|
795 |
+
init_hstate = ScannedRNN.initialize_carry(config["num_train_envs"])
|
796 |
+
# rollout
|
797 |
+
(
|
798 |
+
(rng, train_state, new_hstate, last_obs, last_env_state),
|
799 |
+
(
|
800 |
+
obs,
|
801 |
+
actions,
|
802 |
+
rewards,
|
803 |
+
dones,
|
804 |
+
log_probs,
|
805 |
+
values,
|
806 |
+
info,
|
807 |
+
advantages,
|
808 |
+
targets,
|
809 |
+
losses,
|
810 |
+
grads,
|
811 |
+
rollout_states,
|
812 |
+
),
|
813 |
+
) = sample_trajectories_and_learn(
|
814 |
+
env,
|
815 |
+
env_params,
|
816 |
+
config,
|
817 |
+
rng,
|
818 |
+
train_state,
|
819 |
+
init_hstate,
|
820 |
+
init_obs,
|
821 |
+
init_env_state,
|
822 |
+
update_grad=config["exploratory_grad_updates"],
|
823 |
+
return_states=True,
|
824 |
+
)
|
825 |
+
|
826 |
+
max_returns = compute_max_returns(dones, rewards)
|
827 |
+
scores, all_ued_scores = compute_all_scores(config, dones, values, max_returns, rewards, info, advantages)
|
828 |
+
sampler, _ = level_sampler.insert_batch(sampler, child_levels, scores, {"max_return": max_returns})
|
829 |
+
|
830 |
+
rng, _rng = jax.random.split(rng)
|
831 |
+
metrics = {"update_state": UpdateState.MUTATE,} | get_all_metrics(
|
832 |
+
_rng, losses, info, init_env_state, init_obs, dones, grads, all_ued_scores, child_levels
|
833 |
+
)
|
834 |
+
|
835 |
+
train_state = train_state.replace(
|
836 |
+
sampler=sampler,
|
837 |
+
update_state=UpdateState.DR,
|
838 |
+
num_mutation_updates=train_state.num_mutation_updates + 1,
|
839 |
+
mutation_last_level_batch=child_levels,
|
840 |
+
mutation_last_level_batch_scores=scores,
|
841 |
+
mutation_last_rollout_batch=jax.tree.map(
|
842 |
+
lambda x: x[:, 0], (rollout_states.env_state.env_state.env_state, dones)
|
843 |
+
),
|
844 |
+
)
|
845 |
+
return (rng, train_state), metrics
|
846 |
+
|
847 |
+
rng, train_state = carry
|
848 |
+
rng, rng_replay = jax.random.split(rng)
|
849 |
+
|
850 |
+
# The train step makes a decision on which branch to take, either on_new, on_replay or on_mutate.
|
851 |
+
# on_mutate is only called if the replay branch has been taken before (as it uses `train_state.update_state`).
|
852 |
+
branches = [
|
853 |
+
on_new_levels,
|
854 |
+
on_replay_levels,
|
855 |
+
]
|
856 |
+
if config["use_accel"]:
|
857 |
+
s = train_state.update_state
|
858 |
+
branch = (1 - s) * level_sampler.sample_replay_decision(train_state.sampler, rng_replay) + 2 * s
|
859 |
+
branches.append(on_mutate_levels)
|
860 |
+
else:
|
861 |
+
branch = level_sampler.sample_replay_decision(train_state.sampler, rng_replay).astype(int)
|
862 |
+
|
863 |
+
return jax.lax.switch(branch, branches, rng, train_state)
|
864 |
+
|
865 |
+
@partial(jax.jit, static_argnums=(2,))
|
866 |
+
def eval(rng: chex.PRNGKey, train_state: TrainState, keep_states=True):
|
867 |
+
"""
|
868 |
+
This evaluates the current policy on the set of evaluation levels specified by config["eval_levels"].
|
869 |
+
It returns (states, cum_rewards, episode_lengths), with shapes (num_steps, num_eval_levels, ...), (num_eval_levels,), (num_eval_levels,)
|
870 |
+
"""
|
871 |
+
num_levels = config["num_eval_levels"]
|
872 |
+
return general_eval(
|
873 |
+
rng,
|
874 |
+
eval_env,
|
875 |
+
env_params,
|
876 |
+
train_state,
|
877 |
+
all_eval_levels,
|
878 |
+
env_params.max_timesteps,
|
879 |
+
num_levels,
|
880 |
+
keep_states=keep_states,
|
881 |
+
return_trajectories=True,
|
882 |
+
)
|
883 |
+
|
884 |
+
@partial(jax.jit, static_argnums=(2,))
|
885 |
+
def eval_on_dr_levels(rng: chex.PRNGKey, train_state: TrainState, keep_states=False):
|
886 |
+
return general_eval(
|
887 |
+
rng,
|
888 |
+
env,
|
889 |
+
env_params,
|
890 |
+
train_state,
|
891 |
+
DR_EVAL_LEVELS,
|
892 |
+
env_params.max_timesteps,
|
893 |
+
NUM_EVAL_DR_LEVELS,
|
894 |
+
keep_states=keep_states,
|
895 |
+
)
|
896 |
+
|
897 |
+
@jax.jit
|
898 |
+
def train_and_eval_step(runner_state, _):
|
899 |
+
"""
|
900 |
+
This function runs the train_step for a certain number of iterations, and then evaluates the policy.
|
901 |
+
It returns the updated train state, and a dictionary of metrics.
|
902 |
+
"""
|
903 |
+
# Train
|
904 |
+
(rng, train_state), metrics = jax.lax.scan(train_step, runner_state, None, config["eval_freq"])
|
905 |
+
|
906 |
+
# Eval
|
907 |
+
metrics["update_count"] = (
|
908 |
+
train_state.num_dr_updates + train_state.num_replay_updates + train_state.num_mutation_updates
|
909 |
+
)
|
910 |
+
|
911 |
+
vid_frequency = get_video_frequency(config, metrics["update_count"])
|
912 |
+
should_log_videos = metrics["update_count"] % vid_frequency == 0
|
913 |
+
|
914 |
+
def _compute_eval_learnability(dones, rewards, infos):
|
915 |
+
@jax.vmap
|
916 |
+
def _single(d, r, i):
|
917 |
+
learn, num_eps, num_succ = compute_learnability(config, d, r, i, config["num_eval_levels"])
|
918 |
+
|
919 |
+
return num_eps, num_succ.squeeze(-1)
|
920 |
+
|
921 |
+
num_eps, num_succ = _single(dones, rewards, infos)
|
922 |
+
num_eps, num_succ = num_eps.sum(axis=0), num_succ.sum(axis=0)
|
923 |
+
success_rate = num_succ / jnp.maximum(1, num_eps)
|
924 |
+
|
925 |
+
return success_rate * (1 - success_rate)
|
926 |
+
|
927 |
+
@jax.jit
|
928 |
+
def _get_eval(rng):
|
929 |
+
metrics = {}
|
930 |
+
rng, rng_eval = jax.random.split(rng)
|
931 |
+
(states, cum_rewards, done_idx, episode_lengths, eval_infos), (eval_dones, eval_rewards) = jax.vmap(
|
932 |
+
eval, (0, None)
|
933 |
+
)(jax.random.split(rng_eval, config["eval_num_attempts"]), train_state)
|
934 |
+
|
935 |
+
# learnability here of the holdout set:
|
936 |
+
eval_learn = _compute_eval_learnability(eval_dones, eval_rewards, eval_infos)
|
937 |
+
# Collect Metrics
|
938 |
+
eval_returns = cum_rewards.mean(axis=0) # (num_eval_levels,)
|
939 |
+
eval_solves = (eval_infos["returned_episode_solved"] * eval_dones).sum(axis=1) / jnp.maximum(
|
940 |
+
1, eval_dones.sum(axis=1)
|
941 |
+
)
|
942 |
+
eval_solves = eval_solves.mean(axis=0)
|
943 |
+
metrics["eval_returns"] = eval_returns
|
944 |
+
metrics["eval_ep_lengths"] = episode_lengths.mean(axis=0)
|
945 |
+
metrics["eval_learn"] = eval_learn
|
946 |
+
metrics["eval_solves"] = eval_solves
|
947 |
+
|
948 |
+
metrics["eval_get_max_eplen"] = (episode_lengths == env_params.max_timesteps).mean(axis=0)
|
949 |
+
metrics["episode_return_bigger_than_negative"] = (cum_rewards > -0.4).mean(axis=0)
|
950 |
+
|
951 |
+
if config["EVAL_ON_SAMPLED"]:
|
952 |
+
states_dr, cum_rewards_dr, done_idx_dr, episode_lengths_dr, infos_dr = jax.vmap(
|
953 |
+
eval_on_dr_levels, (0, None)
|
954 |
+
)(jax.random.split(rng_eval, config["eval_num_attempts"]), train_state)
|
955 |
+
|
956 |
+
eval_dr_returns = cum_rewards_dr.mean(axis=0).mean()
|
957 |
+
eval_dr_eplen = episode_lengths_dr.mean(axis=0).mean()
|
958 |
+
|
959 |
+
my_eval_dones = infos_dr["returned_episode"]
|
960 |
+
eval_dr_solves = (infos_dr["returned_episode_solved"] * my_eval_dones).sum(axis=1) / jnp.maximum(
|
961 |
+
1, my_eval_dones.sum(axis=1)
|
962 |
+
)
|
963 |
+
|
964 |
+
metrics["eval_dr_returns"] = eval_dr_returns
|
965 |
+
metrics["eval_dr_eplen"] = eval_dr_eplen
|
966 |
+
metrics["eval_dr_solve_rates"] = eval_dr_solves
|
967 |
+
return metrics, states, episode_lengths, cum_rewards
|
968 |
+
|
969 |
+
@jax.jit
|
970 |
+
def _get_videos(rng, states, episode_lengths, cum_rewards):
|
971 |
+
metrics = {"log_videos": True}
|
972 |
+
|
973 |
+
# just grab the first run
|
974 |
+
states, episode_lengths = jax.tree_util.tree_map(
|
975 |
+
lambda x: x[0], (states, episode_lengths)
|
976 |
+
) # (num_steps, num_eval_levels, ...), (num_eval_levels,)
|
977 |
+
# And one attempt
|
978 |
+
states = jax.tree_util.tree_map(lambda x: x[:, :], states)
|
979 |
+
episode_lengths = episode_lengths[:]
|
980 |
+
images = jax.vmap(jax.vmap(render_fn_eval))(
|
981 |
+
states.env_state.env_state.env_state
|
982 |
+
) # (num_steps, num_eval_levels, ...)
|
983 |
+
frames = images.transpose(
|
984 |
+
0, 1, 4, 2, 3
|
985 |
+
) # WandB expects color channel before image dimensions when dealing with animations for some reason
|
986 |
+
|
987 |
+
@jax.jit
|
988 |
+
def _get_video(rollout_batch):
|
989 |
+
states = rollout_batch[0]
|
990 |
+
images = jax.vmap(render_fn)(states) # dimensions are (steps, x, y, 3)
|
991 |
+
return (
|
992 |
+
# jax.tree.map(lambda x: x[:].transpose(0, 2, 1, 3)[:, ::-1], images).transpose(0, 3, 1, 2),
|
993 |
+
images.transpose(0, 3, 1, 2),
|
994 |
+
# images.transpose(0, 1, 4, 2, 3),
|
995 |
+
rollout_batch[1][:].argmax(),
|
996 |
+
)
|
997 |
+
|
998 |
+
# rollouts
|
999 |
+
metrics["dr_rollout"], metrics["dr_ep_len"] = _get_video(train_state.dr_last_rollout_batch)
|
1000 |
+
metrics["replay_rollout"], metrics["replay_ep_len"] = _get_video(train_state.replay_last_rollout_batch)
|
1001 |
+
metrics["mutation_rollout"], metrics["mutation_ep_len"] = _get_video(
|
1002 |
+
train_state.mutation_last_rollout_batch
|
1003 |
+
)
|
1004 |
+
|
1005 |
+
metrics["eval_animation"] = (frames, episode_lengths)
|
1006 |
+
|
1007 |
+
metrics["eval_returns_video"] = cum_rewards[0]
|
1008 |
+
metrics["eval_len_video"] = episode_lengths
|
1009 |
+
|
1010 |
+
# Eval on sampled
|
1011 |
+
|
1012 |
+
return metrics
|
1013 |
+
|
1014 |
+
@jax.jit
|
1015 |
+
def _get_dummy_videos(rng, states, episode_lengths, cum_rewards):
|
1016 |
+
n_eval = config["num_eval_levels"]
|
1017 |
+
nsteps = env_params.max_timesteps
|
1018 |
+
nsteps2 = config["outer_rollout_steps"] * config["num_steps"]
|
1019 |
+
img_size = (
|
1020 |
+
env.static_env_params.screen_dim[0] // env.static_env_params.downscale,
|
1021 |
+
env.static_env_params.screen_dim[1] // env.static_env_params.downscale,
|
1022 |
+
)
|
1023 |
+
return {
|
1024 |
+
"log_videos": False,
|
1025 |
+
"dr_rollout": jnp.zeros((nsteps2, 3, *img_size), jnp.float32),
|
1026 |
+
"dr_ep_len": jnp.zeros((), jnp.int32),
|
1027 |
+
"replay_rollout": jnp.zeros((nsteps2, 3, *img_size), jnp.float32),
|
1028 |
+
"replay_ep_len": jnp.zeros((), jnp.int32),
|
1029 |
+
"mutation_rollout": jnp.zeros((nsteps2, 3, *img_size), jnp.float32),
|
1030 |
+
"mutation_ep_len": jnp.zeros((), jnp.int32),
|
1031 |
+
# "eval_returns": jnp.zeros((n_eval,), jnp.float32),
|
1032 |
+
# "eval_solves": jnp.zeros((n_eval,), jnp.float32),
|
1033 |
+
# "eval_learn": jnp.zeros((n_eval,), jnp.float32),
|
1034 |
+
# "eval_ep_lengths": jnp.zeros((n_eval,), jnp.int32),
|
1035 |
+
"eval_animation": (
|
1036 |
+
jnp.zeros((nsteps, n_eval, 3, *img_size), jnp.float32),
|
1037 |
+
jnp.zeros((n_eval,), jnp.int32),
|
1038 |
+
),
|
1039 |
+
"eval_returns_video": jnp.zeros((n_eval,), jnp.float32),
|
1040 |
+
"eval_len_video": jnp.zeros((n_eval,), jnp.int32),
|
1041 |
+
}
|
1042 |
+
|
1043 |
+
rng, rng_eval, rng_vid = jax.random.split(rng, 3)
|
1044 |
+
|
1045 |
+
metrics_eval, states, episode_lengths, cum_rewards = _get_eval(rng_eval)
|
1046 |
+
metrics = {
|
1047 |
+
**metrics,
|
1048 |
+
**metrics_eval,
|
1049 |
+
**jax.lax.cond(
|
1050 |
+
should_log_videos, _get_videos, _get_dummy_videos, rng_vid, states, episode_lengths, cum_rewards
|
1051 |
+
),
|
1052 |
+
}
|
1053 |
+
max_num_images = 8
|
1054 |
+
|
1055 |
+
top_regret_ones = max_num_images // 2
|
1056 |
+
bot_regret_ones = max_num_images - top_regret_ones
|
1057 |
+
|
1058 |
+
@jax.jit
|
1059 |
+
def get_values(level_batch, scores):
|
1060 |
+
args = jnp.argsort(scores) # low scores are at the start, high scores are at the end
|
1061 |
+
|
1062 |
+
low_scores = args[:bot_regret_ones]
|
1063 |
+
high_scores = args[-top_regret_ones:]
|
1064 |
+
|
1065 |
+
low_levels = jax.tree.map(lambda x: x[low_scores], level_batch)
|
1066 |
+
high_levels = jax.tree.map(lambda x: x[high_scores], level_batch)
|
1067 |
+
|
1068 |
+
low_scores = scores[low_scores]
|
1069 |
+
high_scores = scores[high_scores]
|
1070 |
+
# now concatenate:
|
1071 |
+
return jax.vmap(render_fn)(
|
1072 |
+
jax.tree.map(lambda x, y: jnp.concatenate([x, y], axis=0), low_levels, high_levels)
|
1073 |
+
), jnp.concatenate([low_scores, high_scores], axis=0)
|
1074 |
+
|
1075 |
+
metrics["dr_levels"], metrics["dr_scores"] = get_values(
|
1076 |
+
train_state.dr_last_level_batch, train_state.dr_last_level_batch_scores
|
1077 |
+
)
|
1078 |
+
metrics["replay_levels"], metrics["replay_scores"] = get_values(
|
1079 |
+
train_state.replay_last_level_batch, train_state.replay_last_level_batch_scores
|
1080 |
+
)
|
1081 |
+
metrics["mutation_levels"], metrics["mutation_scores"] = get_values(
|
1082 |
+
train_state.mutation_last_level_batch, train_state.mutation_last_level_batch_scores
|
1083 |
+
)
|
1084 |
+
|
1085 |
+
def _t(i):
|
1086 |
+
return jax.lax.select(i == 0, config["num_steps"], i)
|
1087 |
+
|
1088 |
+
metrics["dr_ep_len"] = _t(train_state.dr_last_rollout_batch[1][:].argmax())
|
1089 |
+
metrics["replay_ep_len"] = _t(train_state.replay_last_rollout_batch[1][:].argmax())
|
1090 |
+
metrics["mutation_ep_len"] = _t(train_state.mutation_last_rollout_batch[1][:].argmax())
|
1091 |
+
|
1092 |
+
highest_scoring_level = level_sampler.get_levels(train_state.sampler, train_state.sampler["scores"].argmax())
|
1093 |
+
highest_weighted_level = level_sampler.get_levels(
|
1094 |
+
train_state.sampler, level_sampler.level_weights(train_state.sampler).argmax()
|
1095 |
+
)
|
1096 |
+
|
1097 |
+
metrics["highest_scoring_level"] = render_fn(highest_scoring_level)
|
1098 |
+
metrics["highest_weighted_level"] = render_fn(highest_weighted_level)
|
1099 |
+
|
1100 |
+
# log_eval(metrics, train_state_to_log_dict(runner_state[1], level_sampler))
|
1101 |
+
jax.debug.callback(log_eval, metrics, train_state_to_log_dict(runner_state[1], level_sampler))
|
1102 |
+
return (rng, train_state), {"update_count": metrics["update_count"]}
|
1103 |
+
|
1104 |
+
def log_checkpoint(update_count, train_state):
|
1105 |
+
if config["save_path"] is not None and config["checkpoint_save_freq"] > 1:
|
1106 |
+
steps = (
|
1107 |
+
int(update_count)
|
1108 |
+
* int(config["num_train_envs"])
|
1109 |
+
* int(config["num_steps"])
|
1110 |
+
* int(config["outer_rollout_steps"])
|
1111 |
+
)
|
1112 |
+
# save_params_to_wandb(train_state.params, steps, config)
|
1113 |
+
save_model_to_wandb(train_state, steps, config)
|
1114 |
+
|
1115 |
+
def train_eval_and_checkpoint_step(runner_state, _):
|
1116 |
+
runner_state, metrics = jax.lax.scan(
|
1117 |
+
train_and_eval_step, runner_state, xs=jnp.arange(config["checkpoint_save_freq"] // config["eval_freq"])
|
1118 |
+
)
|
1119 |
+
jax.debug.callback(log_checkpoint, metrics["update_count"][-1], runner_state[1])
|
1120 |
+
return runner_state, metrics
|
1121 |
+
|
1122 |
+
# Set up the train states
|
1123 |
+
rng = jax.random.PRNGKey(config["seed"])
|
1124 |
+
rng_init, rng_train = jax.random.split(rng)
|
1125 |
+
|
1126 |
+
train_state = create_train_state(rng_init)
|
1127 |
+
runner_state = (rng_train, train_state)
|
1128 |
+
|
1129 |
+
runner_state, metrics = jax.lax.scan(
|
1130 |
+
train_eval_and_checkpoint_step,
|
1131 |
+
runner_state,
|
1132 |
+
xs=jnp.arange((config["num_updates"]) // (config["checkpoint_save_freq"])),
|
1133 |
+
)
|
1134 |
+
|
1135 |
+
if config["save_path"] is not None:
|
1136 |
+
# save_params_to_wandb(runner_state[1].params, config["total_timesteps"], config)
|
1137 |
+
save_model_to_wandb(runner_state[1], config["total_timesteps"], config, is_final=True)
|
1138 |
+
|
1139 |
+
return runner_state[1]
|
1140 |
+
|
1141 |
+
|
1142 |
+
if __name__ == "__main__":
|
1143 |
+
main()
|
Kinetix/experiments/ppo.py
ADDED
@@ -0,0 +1,468 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import hydra
|
3 |
+
from omegaconf import OmegaConf
|
4 |
+
|
5 |
+
from kinetix.environment.ued.ued import (
|
6 |
+
make_reset_train_function_with_list_of_levels,
|
7 |
+
make_reset_train_function_with_mutations,
|
8 |
+
)
|
9 |
+
from kinetix.render.renderer_pixels import make_render_pixels
|
10 |
+
from kinetix.util.config import (
|
11 |
+
get_video_frequency,
|
12 |
+
init_wandb,
|
13 |
+
normalise_config,
|
14 |
+
generate_params_from_config,
|
15 |
+
)
|
16 |
+
|
17 |
+
os.environ["WANDB_DISABLE_SERVICE"] = "True"
|
18 |
+
|
19 |
+
|
20 |
+
import sys
|
21 |
+
from typing import Any, NamedTuple
|
22 |
+
|
23 |
+
import jax
|
24 |
+
import jax.numpy as jnp
|
25 |
+
import numpy as np
|
26 |
+
import optax
|
27 |
+
from flax.training.train_state import TrainState
|
28 |
+
|
29 |
+
from kinetix.models import make_network_from_config
|
30 |
+
from kinetix.util.learning import general_eval, get_eval_levels
|
31 |
+
from flax.serialization import to_state_dict
|
32 |
+
|
33 |
+
import wandb
|
34 |
+
from kinetix.environment.env import PixelObservations, make_kinetix_env_from_name
|
35 |
+
from kinetix.environment.wrappers import (
|
36 |
+
AutoReplayWrapper,
|
37 |
+
AutoResetWrapper,
|
38 |
+
BatchEnvWrapper,
|
39 |
+
DenseRewardWrapper,
|
40 |
+
LogWrapper,
|
41 |
+
UnderspecifiedToGymnaxWrapper,
|
42 |
+
)
|
43 |
+
from kinetix.models.actor_critic import ScannedRNN
|
44 |
+
from kinetix.util.saving import (
|
45 |
+
load_train_state_from_wandb_artifact_path,
|
46 |
+
save_model_to_wandb,
|
47 |
+
)
|
48 |
+
|
49 |
+
|
50 |
+
class Transition(NamedTuple):
|
51 |
+
done: jnp.ndarray
|
52 |
+
action: jnp.ndarray
|
53 |
+
value: jnp.ndarray
|
54 |
+
reward: jnp.ndarray
|
55 |
+
log_prob: jnp.ndarray
|
56 |
+
obs: Any
|
57 |
+
info: jnp.ndarray
|
58 |
+
|
59 |
+
|
60 |
+
def make_train(config, env_params, static_env_params):
|
61 |
+
config["num_updates"] = config["total_timesteps"] // config["num_steps"] // config["num_train_envs"]
|
62 |
+
config["minibatch_size"] = config["num_train_envs"] * config["num_steps"] // config["num_minibatches"]
|
63 |
+
|
64 |
+
env = make_kinetix_env_from_name(config["env_name"], static_env_params=static_env_params)
|
65 |
+
|
66 |
+
if config["train_level_mode"] == "list":
|
67 |
+
reset_func = make_reset_train_function_with_list_of_levels(
|
68 |
+
config, config["train_levels_list"], static_env_params, is_loading_train_levels=True
|
69 |
+
)
|
70 |
+
elif config["train_level_mode"] == "random":
|
71 |
+
reset_func = make_reset_train_function_with_mutations(
|
72 |
+
env.physics_engine, env_params, env.static_env_params, config
|
73 |
+
)
|
74 |
+
else:
|
75 |
+
raise ValueError(f"Unknown train_level_mode: {config['train_level_mode']}")
|
76 |
+
|
77 |
+
env = UnderspecifiedToGymnaxWrapper(AutoResetWrapper(env, reset_func))
|
78 |
+
|
79 |
+
eval_env = make_kinetix_env_from_name(config["env_name"], static_env_params=static_env_params)
|
80 |
+
eval_env = UnderspecifiedToGymnaxWrapper(AutoReplayWrapper(eval_env))
|
81 |
+
|
82 |
+
env = DenseRewardWrapper(env)
|
83 |
+
env = LogWrapper(env)
|
84 |
+
env = BatchEnvWrapper(env, num_envs=config["num_train_envs"])
|
85 |
+
|
86 |
+
eval_env_nonbatch = LogWrapper(DenseRewardWrapper(eval_env))
|
87 |
+
|
88 |
+
def linear_schedule(count):
|
89 |
+
frac = 1.0 - (count // (config["num_minibatches"] * config["update_epochs"])) / config["num_updates"]
|
90 |
+
return config["lr"] * frac
|
91 |
+
|
92 |
+
def linear_warmup_cosine_decay_schedule(count):
|
93 |
+
frac = (count // (config["num_minibatches"] * config["update_epochs"])) / config[
|
94 |
+
"num_updates"
|
95 |
+
] # between 0 and 1
|
96 |
+
delta = config["peak_lr"] - config["initial_lr"]
|
97 |
+
frac_diff_max = 1.0 - config["warmup_frac"]
|
98 |
+
frac_cosine = (frac - config["warmup_frac"]) / frac_diff_max
|
99 |
+
|
100 |
+
return jax.lax.select(
|
101 |
+
frac < config["warmup_frac"],
|
102 |
+
config["initial_lr"] + delta * frac / config["warmup_frac"],
|
103 |
+
config["peak_lr"] * jnp.maximum(0.0, 0.5 * (1.0 + jnp.cos(jnp.pi * ((frac_cosine) % 1.0)))),
|
104 |
+
)
|
105 |
+
|
106 |
+
def train(rng):
|
107 |
+
# INIT NETWORK
|
108 |
+
network = make_network_from_config(env, env_params, config)
|
109 |
+
rng, _rng = jax.random.split(rng)
|
110 |
+
obsv, env_state = env.reset(_rng, env_params)
|
111 |
+
dones = jnp.zeros((config["num_train_envs"]), dtype=jnp.bool_)
|
112 |
+
rng, _rng = jax.random.split(rng)
|
113 |
+
init_hstate = ScannedRNN.initialize_carry(config["num_train_envs"])
|
114 |
+
init_x = jax.tree.map(lambda x: x[None, ...], (obsv, dones))
|
115 |
+
network_params = network.init(_rng, init_hstate, init_x)
|
116 |
+
|
117 |
+
param_count = sum(x.size for x in jax.tree_util.tree_leaves(network_params))
|
118 |
+
obs_size = sum(x.size for x in jax.tree_util.tree_leaves(obsv)) // config["num_train_envs"]
|
119 |
+
|
120 |
+
print("Number of parameters", param_count, "size of obs: ", obs_size)
|
121 |
+
if config["anneal_lr"]:
|
122 |
+
tx = optax.chain(
|
123 |
+
optax.clip_by_global_norm(config["max_grad_norm"]),
|
124 |
+
optax.adam(learning_rate=linear_schedule, eps=1e-5),
|
125 |
+
)
|
126 |
+
elif config["warmup_lr"]:
|
127 |
+
tx = optax.chain(
|
128 |
+
optax.clip_by_global_norm(config["max_grad_norm"]),
|
129 |
+
optax.adamw(learning_rate=linear_warmup_cosine_decay_schedule, eps=1e-5),
|
130 |
+
)
|
131 |
+
else:
|
132 |
+
tx = optax.chain(
|
133 |
+
optax.clip_by_global_norm(config["max_grad_norm"]),
|
134 |
+
optax.adam(config["lr"], eps=1e-5),
|
135 |
+
)
|
136 |
+
train_state = TrainState.create(
|
137 |
+
apply_fn=network.apply,
|
138 |
+
params=network_params,
|
139 |
+
tx=tx,
|
140 |
+
)
|
141 |
+
if config["load_from_checkpoint"] != None:
|
142 |
+
print("LOADING from", config["load_from_checkpoint"], "with only params =", config["load_only_params"])
|
143 |
+
train_state = load_train_state_from_wandb_artifact_path(
|
144 |
+
train_state, config["load_from_checkpoint"], load_only_params=config["load_only_params"]
|
145 |
+
)
|
146 |
+
# INIT ENV
|
147 |
+
rng, _rng = jax.random.split(rng)
|
148 |
+
obsv, env_state = env.reset(_rng, env_params)
|
149 |
+
init_hstate = ScannedRNN.initialize_carry(config["num_train_envs"])
|
150 |
+
render_static_env_params = env.static_env_params.replace(downscale=1)
|
151 |
+
pixel_renderer = jax.jit(make_render_pixels(env_params, render_static_env_params))
|
152 |
+
pixel_render_fn = lambda x: pixel_renderer(x) / 255.0
|
153 |
+
eval_levels = get_eval_levels(config["eval_levels"], env.static_env_params)
|
154 |
+
|
155 |
+
def _vmapped_eval_step(runner_state, rng):
|
156 |
+
def _single_eval_step(rng):
|
157 |
+
return general_eval(
|
158 |
+
rng,
|
159 |
+
eval_env_nonbatch,
|
160 |
+
env_params,
|
161 |
+
runner_state[0],
|
162 |
+
eval_levels,
|
163 |
+
env_params.max_timesteps,
|
164 |
+
config["num_eval_levels"],
|
165 |
+
keep_states=True,
|
166 |
+
return_trajectories=True,
|
167 |
+
)
|
168 |
+
|
169 |
+
(states, returns, done_idxs, episode_lengths, eval_infos), (eval_dones, eval_rewards) = jax.vmap(
|
170 |
+
_single_eval_step
|
171 |
+
)(jax.random.split(rng, config["eval_num_attempts"]))
|
172 |
+
eval_solves = (eval_infos["returned_episode_solved"] * eval_dones).sum(axis=1) / jnp.maximum(
|
173 |
+
1, eval_dones.sum(axis=1)
|
174 |
+
)
|
175 |
+
states_to_plot = jax.tree.map(lambda x: x[0], states)
|
176 |
+
# obs = jax.vmap(jax.vmap(pixel_render_fn))(states_to_plot.env_state.env_state.env_state)
|
177 |
+
|
178 |
+
return (
|
179 |
+
states_to_plot,
|
180 |
+
done_idxs[0],
|
181 |
+
returns[0],
|
182 |
+
returns.mean(axis=0),
|
183 |
+
episode_lengths.mean(axis=0),
|
184 |
+
eval_solves.mean(axis=0),
|
185 |
+
)
|
186 |
+
|
187 |
+
# TRAIN LOOP
|
188 |
+
def _update_step(runner_state, unused):
|
189 |
+
# COLLECT TRAJECTORIES
|
190 |
+
def _env_step(runner_state, unused):
|
191 |
+
(
|
192 |
+
train_state,
|
193 |
+
env_state,
|
194 |
+
last_obs,
|
195 |
+
last_done,
|
196 |
+
hstate,
|
197 |
+
rng,
|
198 |
+
update_step,
|
199 |
+
) = runner_state
|
200 |
+
|
201 |
+
# SELECT ACTION
|
202 |
+
rng, _rng = jax.random.split(rng)
|
203 |
+
ac_in = (jax.tree.map(lambda x: x[np.newaxis, :], last_obs), last_done[np.newaxis, :])
|
204 |
+
hstate, pi, value = network.apply(train_state.params, hstate, ac_in)
|
205 |
+
action = pi.sample(seed=_rng)
|
206 |
+
log_prob = pi.log_prob(action)
|
207 |
+
value, action, log_prob = (
|
208 |
+
value.squeeze(0),
|
209 |
+
action.squeeze(0),
|
210 |
+
log_prob.squeeze(0),
|
211 |
+
)
|
212 |
+
|
213 |
+
# STEP ENV
|
214 |
+
rng, _rng = jax.random.split(rng)
|
215 |
+
obsv, env_state, reward, done, info = env.step(_rng, env_state, action, env_params)
|
216 |
+
transition = Transition(last_done, action, value, reward, log_prob, last_obs, info)
|
217 |
+
runner_state = (
|
218 |
+
train_state,
|
219 |
+
env_state,
|
220 |
+
obsv,
|
221 |
+
done,
|
222 |
+
hstate,
|
223 |
+
rng,
|
224 |
+
update_step,
|
225 |
+
)
|
226 |
+
return runner_state, transition
|
227 |
+
|
228 |
+
initial_hstate = runner_state[-3]
|
229 |
+
runner_state, traj_batch = jax.lax.scan(_env_step, runner_state, None, config["num_steps"])
|
230 |
+
|
231 |
+
# CALCULATE ADVANTAGE
|
232 |
+
(
|
233 |
+
train_state,
|
234 |
+
env_state,
|
235 |
+
last_obs,
|
236 |
+
last_done,
|
237 |
+
hstate,
|
238 |
+
rng,
|
239 |
+
update_step,
|
240 |
+
) = runner_state
|
241 |
+
ac_in = (jax.tree.map(lambda x: x[np.newaxis, :], last_obs), last_done[np.newaxis, :])
|
242 |
+
_, _, last_val = network.apply(train_state.params, hstate, ac_in)
|
243 |
+
last_val = last_val.squeeze(0)
|
244 |
+
|
245 |
+
def _calculate_gae(traj_batch, last_val, last_done):
|
246 |
+
def _get_advantages(carry, transition):
|
247 |
+
gae, next_value, next_done = carry
|
248 |
+
done, value, reward = (
|
249 |
+
transition.done,
|
250 |
+
transition.value,
|
251 |
+
transition.reward,
|
252 |
+
)
|
253 |
+
delta = reward + config["gamma"] * next_value * (1 - next_done) - value
|
254 |
+
gae = delta + config["gamma"] * config["gae_lambda"] * (1 - next_done) * gae
|
255 |
+
return (gae, value, done), gae
|
256 |
+
|
257 |
+
_, advantages = jax.lax.scan(
|
258 |
+
_get_advantages,
|
259 |
+
(jnp.zeros_like(last_val), last_val, last_done),
|
260 |
+
traj_batch,
|
261 |
+
reverse=True,
|
262 |
+
unroll=16,
|
263 |
+
)
|
264 |
+
return advantages, advantages + traj_batch.value
|
265 |
+
|
266 |
+
advantages, targets = _calculate_gae(traj_batch, last_val, last_done)
|
267 |
+
|
268 |
+
# UPDATE NETWORK
|
269 |
+
def _update_epoch(update_state, unused):
|
270 |
+
def _update_minbatch(train_state, batch_info):
|
271 |
+
init_hstate, traj_batch, advantages, targets = batch_info
|
272 |
+
|
273 |
+
def _loss_fn(params, init_hstate, traj_batch, gae, targets):
|
274 |
+
# RERUN NETWORK
|
275 |
+
_, pi, value = network.apply(params, init_hstate[0], (traj_batch.obs, traj_batch.done))
|
276 |
+
log_prob = pi.log_prob(traj_batch.action)
|
277 |
+
|
278 |
+
# CALCULATE VALUE LOSS
|
279 |
+
value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip(
|
280 |
+
-config["clip_eps"], config["clip_eps"]
|
281 |
+
)
|
282 |
+
value_losses = jnp.square(value - targets)
|
283 |
+
value_losses_clipped = jnp.square(value_pred_clipped - targets)
|
284 |
+
value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
|
285 |
+
|
286 |
+
# CALCULATE ACTOR LOSS
|
287 |
+
ratio = jnp.exp(log_prob - traj_batch.log_prob)
|
288 |
+
gae = (gae - gae.mean()) / (gae.std() + 1e-8)
|
289 |
+
loss_actor1 = ratio * gae
|
290 |
+
loss_actor2 = (
|
291 |
+
jnp.clip(
|
292 |
+
ratio,
|
293 |
+
1.0 - config["clip_eps"],
|
294 |
+
1.0 + config["clip_eps"],
|
295 |
+
)
|
296 |
+
* gae
|
297 |
+
)
|
298 |
+
loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
|
299 |
+
loss_actor = loss_actor.mean()
|
300 |
+
entropy = pi.entropy().mean()
|
301 |
+
|
302 |
+
total_loss = loss_actor + config["vf_coef"] * value_loss - config["ent_coef"] * entropy
|
303 |
+
return total_loss, (value_loss, loss_actor, entropy)
|
304 |
+
|
305 |
+
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
|
306 |
+
total_loss, grads = grad_fn(train_state.params, init_hstate, traj_batch, advantages, targets)
|
307 |
+
train_state = train_state.apply_gradients(grads=grads)
|
308 |
+
return train_state, total_loss
|
309 |
+
|
310 |
+
(
|
311 |
+
train_state,
|
312 |
+
init_hstate,
|
313 |
+
traj_batch,
|
314 |
+
advantages,
|
315 |
+
targets,
|
316 |
+
rng,
|
317 |
+
) = update_state
|
318 |
+
rng, _rng = jax.random.split(rng)
|
319 |
+
permutation = jax.random.permutation(_rng, config["num_train_envs"])
|
320 |
+
batch = (init_hstate, traj_batch, advantages, targets)
|
321 |
+
|
322 |
+
shuffled_batch = jax.tree_util.tree_map(lambda x: jnp.take(x, permutation, axis=1), batch)
|
323 |
+
|
324 |
+
minibatches = jax.tree_util.tree_map(
|
325 |
+
lambda x: jnp.swapaxes(
|
326 |
+
jnp.reshape(
|
327 |
+
x,
|
328 |
+
[x.shape[0], config["num_minibatches"], -1] + list(x.shape[2:]),
|
329 |
+
),
|
330 |
+
1,
|
331 |
+
0,
|
332 |
+
),
|
333 |
+
shuffled_batch,
|
334 |
+
)
|
335 |
+
|
336 |
+
train_state, total_loss = jax.lax.scan(_update_minbatch, train_state, minibatches)
|
337 |
+
update_state = (
|
338 |
+
train_state,
|
339 |
+
init_hstate,
|
340 |
+
traj_batch,
|
341 |
+
advantages,
|
342 |
+
targets,
|
343 |
+
rng,
|
344 |
+
)
|
345 |
+
return update_state, total_loss
|
346 |
+
|
347 |
+
init_hstate = initial_hstate[None, :] # TBH
|
348 |
+
update_state = (
|
349 |
+
train_state,
|
350 |
+
init_hstate,
|
351 |
+
traj_batch,
|
352 |
+
advantages,
|
353 |
+
targets,
|
354 |
+
rng,
|
355 |
+
)
|
356 |
+
update_state, loss_info = jax.lax.scan(_update_epoch, update_state, None, config["update_epochs"])
|
357 |
+
train_state = update_state[0]
|
358 |
+
metric = jax.tree.map(
|
359 |
+
lambda x: (x * traj_batch.info["returned_episode"]).sum() / traj_batch.info["returned_episode"].sum(),
|
360 |
+
traj_batch.info,
|
361 |
+
)
|
362 |
+
rng = update_state[-1]
|
363 |
+
|
364 |
+
if config["use_wandb"]:
|
365 |
+
vid_frequency = get_video_frequency(config, update_step)
|
366 |
+
rng, _rng = jax.random.split(rng)
|
367 |
+
to_log_videos = _vmapped_eval_step(runner_state, _rng)
|
368 |
+
should_log_videos = update_step % vid_frequency == 0
|
369 |
+
first = jax.lax.cond(
|
370 |
+
should_log_videos,
|
371 |
+
lambda: jax.vmap(jax.vmap(pixel_render_fn))(to_log_videos[0].env_state.env_state.env_state),
|
372 |
+
lambda: (
|
373 |
+
jnp.zeros(
|
374 |
+
(
|
375 |
+
env_params.max_timesteps,
|
376 |
+
config["num_eval_levels"],
|
377 |
+
*PixelObservations(env_params, render_static_env_params)
|
378 |
+
.observation_space(env_params)
|
379 |
+
.shape,
|
380 |
+
)
|
381 |
+
)
|
382 |
+
),
|
383 |
+
)
|
384 |
+
to_log_videos = (first, should_log_videos, *to_log_videos[1:])
|
385 |
+
|
386 |
+
def callback(metric, raw_info, loss_info, update_step, to_log_videos):
|
387 |
+
to_log = {}
|
388 |
+
to_log["timing/num_updates"] = update_step
|
389 |
+
to_log["timing/num_env_steps"] = update_step * config["num_steps"] * config["num_train_envs"]
|
390 |
+
(
|
391 |
+
obs_vid,
|
392 |
+
should_log_videos,
|
393 |
+
idx_vid,
|
394 |
+
eval_return_vid,
|
395 |
+
eval_return_mean,
|
396 |
+
eval_eplen_mean,
|
397 |
+
eval_solverate_mean,
|
398 |
+
) = to_log_videos
|
399 |
+
to_log["eval/mean_eval_return"] = eval_return_mean.mean()
|
400 |
+
to_log["eval/mean_eval_eplen"] = eval_eplen_mean.mean()
|
401 |
+
for i, eval_name in enumerate(config["eval_levels"]):
|
402 |
+
return_on_video = eval_return_vid[i]
|
403 |
+
to_log[f"eval_video/return_{eval_name}"] = return_on_video
|
404 |
+
to_log[f"eval_video/len_{eval_name}"] = idx_vid[i]
|
405 |
+
to_log[f"eval_avg/return_{eval_name}"] = eval_return_mean[i]
|
406 |
+
to_log[f"eval_avg/solve_rate_{eval_name}"] = eval_solverate_mean[i]
|
407 |
+
|
408 |
+
if should_log_videos:
|
409 |
+
for i, eval_name in enumerate(config["eval_levels"]):
|
410 |
+
obs_to_use = obs_vid[: idx_vid[i], i]
|
411 |
+
obs_to_use = np.asarray(obs_to_use).transpose(0, 3, 2, 1)[:, :, ::-1, :]
|
412 |
+
to_log[f"media/eval_video_{eval_name}"] = wandb.Video((obs_to_use * 255).astype(np.uint8))
|
413 |
+
|
414 |
+
wandb.log(to_log)
|
415 |
+
|
416 |
+
jax.debug.callback(callback, metric, traj_batch.info, loss_info, update_step, to_log_videos)
|
417 |
+
|
418 |
+
runner_state = (
|
419 |
+
train_state,
|
420 |
+
env_state,
|
421 |
+
last_obs,
|
422 |
+
last_done,
|
423 |
+
hstate,
|
424 |
+
rng,
|
425 |
+
update_step + 1,
|
426 |
+
)
|
427 |
+
return runner_state, metric
|
428 |
+
|
429 |
+
rng, _rng = jax.random.split(rng)
|
430 |
+
runner_state = (
|
431 |
+
train_state,
|
432 |
+
env_state,
|
433 |
+
obsv,
|
434 |
+
jnp.zeros((config["num_train_envs"]), dtype=bool),
|
435 |
+
init_hstate,
|
436 |
+
_rng,
|
437 |
+
0,
|
438 |
+
)
|
439 |
+
runner_state, metric = jax.lax.scan(_update_step, runner_state, None, config["num_updates"])
|
440 |
+
return {"runner_state": runner_state, "metric": metric}
|
441 |
+
|
442 |
+
return train
|
443 |
+
|
444 |
+
|
445 |
+
@hydra.main(version_base=None, config_path="../configs", config_name="ppo")
|
446 |
+
def main(config):
|
447 |
+
config = normalise_config(OmegaConf.to_container(config), "PPO")
|
448 |
+
env_params, static_env_params = generate_params_from_config(config)
|
449 |
+
config["env_params"] = to_state_dict(env_params)
|
450 |
+
config["static_env_params"] = to_state_dict(static_env_params)
|
451 |
+
|
452 |
+
if config["use_wandb"]:
|
453 |
+
run = init_wandb(config, "PPO")
|
454 |
+
|
455 |
+
rng = jax.random.PRNGKey(config["seed"])
|
456 |
+
rng, _rng = jax.random.split(rng)
|
457 |
+
train_jit = jax.jit(make_train(config, env_params, static_env_params))
|
458 |
+
|
459 |
+
out = train_jit(_rng)
|
460 |
+
|
461 |
+
if config["use_wandb"]:
|
462 |
+
if config["save_policy"]:
|
463 |
+
train_state = jax.tree.map(lambda x: x, out["runner_state"][0])
|
464 |
+
save_model_to_wandb(train_state, config["total_timesteps"], config)
|
465 |
+
|
466 |
+
|
467 |
+
if __name__ == "__main__":
|
468 |
+
main()
|
Kinetix/experiments/sfl.py
ADDED
@@ -0,0 +1,1067 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Based on PureJaxRL Implementation of PPO
|
3 |
+
"""
|
4 |
+
|
5 |
+
import os
|
6 |
+
import sys
|
7 |
+
import time
|
8 |
+
import typing
|
9 |
+
from functools import partial
|
10 |
+
from typing import NamedTuple
|
11 |
+
|
12 |
+
import chex
|
13 |
+
import hydra
|
14 |
+
import jax
|
15 |
+
import jax.experimental
|
16 |
+
import jax.numpy as jnp
|
17 |
+
import matplotlib.pyplot as plt
|
18 |
+
import numpy as np
|
19 |
+
import optax
|
20 |
+
from flax.training.train_state import TrainState
|
21 |
+
from kinetix.environment.ued.ued import make_reset_train_function_with_mutations, make_vmapped_filtered_level_sampler
|
22 |
+
from kinetix.environment.ued.ued import (
|
23 |
+
make_reset_train_function_with_list_of_levels,
|
24 |
+
make_reset_train_function_with_mutations,
|
25 |
+
)
|
26 |
+
from kinetix.util.config import (
|
27 |
+
generate_ued_params_from_config,
|
28 |
+
init_wandb,
|
29 |
+
normalise_config,
|
30 |
+
generate_params_from_config,
|
31 |
+
get_eval_level_groups,
|
32 |
+
)
|
33 |
+
from jaxued.environments.underspecified_env import EnvParams, EnvState, Observation, UnderspecifiedEnv
|
34 |
+
from omegaconf import OmegaConf
|
35 |
+
from PIL import Image
|
36 |
+
from flax.serialization import to_state_dict
|
37 |
+
|
38 |
+
import wandb
|
39 |
+
from kinetix.environment.env import make_kinetix_env_from_name
|
40 |
+
from kinetix.environment.wrappers import (
|
41 |
+
AutoReplayWrapper,
|
42 |
+
DenseRewardWrapper,
|
43 |
+
LogWrapper,
|
44 |
+
UnderspecifiedToGymnaxWrapper,
|
45 |
+
)
|
46 |
+
from kinetix.models import make_network_from_config
|
47 |
+
from kinetix.models.actor_critic import ScannedRNN
|
48 |
+
from kinetix.render.renderer_pixels import make_render_pixels
|
49 |
+
from kinetix.util.learning import general_eval, get_eval_levels
|
50 |
+
from kinetix.util.saving import (
|
51 |
+
load_train_state_from_wandb_artifact_path,
|
52 |
+
save_model_to_wandb,
|
53 |
+
)
|
54 |
+
|
55 |
+
sys.path.append("ued")
|
56 |
+
from flax.traverse_util import flatten_dict, unflatten_dict
|
57 |
+
from safetensors.flax import load_file, save_file
|
58 |
+
|
59 |
+
|
60 |
+
def save_params(params: typing.Dict, filename: typing.Union[str, os.PathLike]) -> None:
|
61 |
+
flattened_dict = flatten_dict(params, sep=",")
|
62 |
+
save_file(flattened_dict, filename)
|
63 |
+
|
64 |
+
|
65 |
+
def load_params(filename: typing.Union[str, os.PathLike]) -> typing.Dict:
|
66 |
+
flattened_dict = load_file(filename)
|
67 |
+
return unflatten_dict(flattened_dict, sep=",")
|
68 |
+
|
69 |
+
|
70 |
+
class Transition(NamedTuple):
|
71 |
+
global_done: jnp.ndarray
|
72 |
+
done: jnp.ndarray
|
73 |
+
action: jnp.ndarray
|
74 |
+
value: jnp.ndarray
|
75 |
+
reward: jnp.ndarray
|
76 |
+
log_prob: jnp.ndarray
|
77 |
+
obs: jnp.ndarray
|
78 |
+
info: jnp.ndarray
|
79 |
+
|
80 |
+
|
81 |
+
class RolloutBatch(NamedTuple):
|
82 |
+
obs: jnp.ndarray
|
83 |
+
actions: jnp.ndarray
|
84 |
+
rewards: jnp.ndarray
|
85 |
+
dones: jnp.ndarray
|
86 |
+
log_probs: jnp.ndarray
|
87 |
+
values: jnp.ndarray
|
88 |
+
targets: jnp.ndarray
|
89 |
+
advantages: jnp.ndarray
|
90 |
+
# carry: jnp.ndarray
|
91 |
+
mask: jnp.ndarray
|
92 |
+
|
93 |
+
|
94 |
+
def evaluate_rnn(
|
95 |
+
rng: chex.PRNGKey,
|
96 |
+
env: UnderspecifiedEnv,
|
97 |
+
env_params: EnvParams,
|
98 |
+
train_state: TrainState,
|
99 |
+
init_hstate: chex.ArrayTree,
|
100 |
+
init_obs: Observation,
|
101 |
+
init_env_state: EnvState,
|
102 |
+
max_episode_length: int,
|
103 |
+
keep_states=True,
|
104 |
+
) -> tuple[chex.Array, chex.Array, chex.Array]:
|
105 |
+
"""This runs the RNN on the environment, given an initial state and observation, and returns (states, rewards, episode_lengths)
|
106 |
+
|
107 |
+
Args:
|
108 |
+
rng (chex.PRNGKey):
|
109 |
+
env (UnderspecifiedEnv):
|
110 |
+
env_params (EnvParams):
|
111 |
+
train_state (TrainState):
|
112 |
+
init_hstate (chex.ArrayTree): Shape (num_levels, )
|
113 |
+
init_obs (Observation): Shape (num_levels, )
|
114 |
+
init_env_state (EnvState): Shape (num_levels, )
|
115 |
+
max_episode_length (int):
|
116 |
+
|
117 |
+
Returns:
|
118 |
+
Tuple[chex.Array, chex.Array, chex.Array]: (States, rewards, episode lengths) ((NUM_STEPS, NUM_LEVELS), (NUM_STEPS, NUM_LEVELS), (NUM_LEVELS,)
|
119 |
+
"""
|
120 |
+
num_levels = jax.tree_util.tree_flatten(init_obs)[0][0].shape[0]
|
121 |
+
|
122 |
+
def step(carry, _):
|
123 |
+
rng, hstate, obs, state, done, mask, episode_length = carry
|
124 |
+
rng, rng_action, rng_step = jax.random.split(rng, 3)
|
125 |
+
|
126 |
+
x = jax.tree.map(lambda x: x[None, ...], (obs, done))
|
127 |
+
hstate, pi, _ = train_state.apply_fn(train_state.params, hstate, x)
|
128 |
+
action = pi.sample(seed=rng_action).squeeze(0)
|
129 |
+
|
130 |
+
obs, next_state, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0, None))(
|
131 |
+
jax.random.split(rng_step, num_levels), state, action, env_params
|
132 |
+
)
|
133 |
+
|
134 |
+
next_mask = mask & ~done
|
135 |
+
episode_length += mask
|
136 |
+
|
137 |
+
if keep_states:
|
138 |
+
return (rng, hstate, obs, next_state, done, next_mask, episode_length), (state, reward, info)
|
139 |
+
else:
|
140 |
+
return (rng, hstate, obs, next_state, done, next_mask, episode_length), (None, reward, info)
|
141 |
+
|
142 |
+
(_, _, _, _, _, _, episode_lengths), (states, rewards, infos) = jax.lax.scan(
|
143 |
+
step,
|
144 |
+
(
|
145 |
+
rng,
|
146 |
+
init_hstate,
|
147 |
+
init_obs,
|
148 |
+
init_env_state,
|
149 |
+
jnp.zeros(num_levels, dtype=bool),
|
150 |
+
jnp.ones(num_levels, dtype=bool),
|
151 |
+
jnp.zeros(num_levels, dtype=jnp.int32),
|
152 |
+
),
|
153 |
+
None,
|
154 |
+
length=max_episode_length,
|
155 |
+
)
|
156 |
+
|
157 |
+
return states, rewards, episode_lengths, infos
|
158 |
+
|
159 |
+
|
160 |
+
@hydra.main(version_base=None, config_path="../configs", config_name="sfl")
|
161 |
+
def main(config):
|
162 |
+
time_start = time.time()
|
163 |
+
config = OmegaConf.to_container(config)
|
164 |
+
config = normalise_config(config, "SFL" if config["ued"]["sampled_envs_ratio"] > 0 else "SFL-DR")
|
165 |
+
env_params, static_env_params = generate_params_from_config(config)
|
166 |
+
config["env_params"] = to_state_dict(env_params)
|
167 |
+
config["static_env_params"] = to_state_dict(static_env_params)
|
168 |
+
run = init_wandb(config, "SFL")
|
169 |
+
|
170 |
+
rng = jax.random.PRNGKey(config["seed"])
|
171 |
+
|
172 |
+
config["num_envs_from_sampled"] = int(config["num_train_envs"] * config["sampled_envs_ratio"])
|
173 |
+
config["num_envs_to_generate"] = int(config["num_train_envs"] * (1 - config["sampled_envs_ratio"]))
|
174 |
+
assert (config["num_envs_from_sampled"] + config["num_envs_to_generate"]) == config["num_train_envs"]
|
175 |
+
|
176 |
+
def make_env(static_env_params):
|
177 |
+
env = make_kinetix_env_from_name(config["env_name"], static_env_params=static_env_params)
|
178 |
+
env = AutoReplayWrapper(env)
|
179 |
+
env = UnderspecifiedToGymnaxWrapper(env)
|
180 |
+
env = DenseRewardWrapper(env, dense_reward_scale=config["dense_reward_scale"])
|
181 |
+
env = LogWrapper(env)
|
182 |
+
return env
|
183 |
+
|
184 |
+
env = make_env(static_env_params)
|
185 |
+
|
186 |
+
if config["train_level_mode"] == "list":
|
187 |
+
sample_random_level = make_reset_train_function_with_list_of_levels(
|
188 |
+
config, config["train_levels"], static_env_params, make_pcg_state=False, is_loading_train_levels=True
|
189 |
+
)
|
190 |
+
elif config["train_level_mode"] == "random":
|
191 |
+
sample_random_level = make_reset_train_function_with_mutations(
|
192 |
+
env.physics_engine, env_params, static_env_params, config, make_pcg_state=False
|
193 |
+
)
|
194 |
+
else:
|
195 |
+
raise ValueError(f"Unknown train_level_mode: {config['train_level_mode']}")
|
196 |
+
|
197 |
+
sample_random_levels = make_vmapped_filtered_level_sampler(
|
198 |
+
sample_random_level, env_params, static_env_params, config, make_pcg_state=False, env=env
|
199 |
+
)
|
200 |
+
_, eval_static_env_params = generate_params_from_config(
|
201 |
+
config["eval_env_size_true"] | {"frame_skip": config["frame_skip"]}
|
202 |
+
)
|
203 |
+
eval_env = make_env(eval_static_env_params)
|
204 |
+
ued_params = generate_ued_params_from_config(config)
|
205 |
+
|
206 |
+
def make_render_fn(static_env_params):
|
207 |
+
render_fn_inner = make_render_pixels(env_params, static_env_params)
|
208 |
+
render_fn = lambda x: render_fn_inner(x).transpose(1, 0, 2)[::-1]
|
209 |
+
return render_fn
|
210 |
+
|
211 |
+
render_fn = make_render_fn(static_env_params)
|
212 |
+
render_fn_eval = make_render_fn(eval_static_env_params)
|
213 |
+
|
214 |
+
NUM_EVAL_DR_LEVELS = 200
|
215 |
+
key_to_sample_dr_eval_set = jax.random.PRNGKey(100)
|
216 |
+
DR_EVAL_LEVELS = sample_random_levels(key_to_sample_dr_eval_set, NUM_EVAL_DR_LEVELS)
|
217 |
+
|
218 |
+
print("Hello here num steps is ", config["num_steps"])
|
219 |
+
print("CONFIG is ", config)
|
220 |
+
|
221 |
+
config["total_timesteps"] = config["num_updates"] * config["num_steps"] * config["num_train_envs"]
|
222 |
+
config["minibatch_size"] = config["num_train_envs"] * config["num_steps"] // config["num_minibatches"]
|
223 |
+
config["clip_eps"] = config["clip_eps"]
|
224 |
+
|
225 |
+
config["env_name"] = config["env_name"]
|
226 |
+
network = make_network_from_config(env, env_params, config)
|
227 |
+
|
228 |
+
def linear_schedule(count):
|
229 |
+
count = count // (config["num_minibatches"] * config["update_epochs"])
|
230 |
+
frac = 1.0 - count / config["num_updates"]
|
231 |
+
return config["lr"] * frac
|
232 |
+
|
233 |
+
# INIT NETWORK
|
234 |
+
rng, _rng = jax.random.split(rng)
|
235 |
+
train_envs = 32 # To not run out of memory, the initial sample size does not matter.
|
236 |
+
obs, _ = env.reset_to_level(rng, sample_random_level(rng), env_params)
|
237 |
+
obs = jax.tree.map(
|
238 |
+
lambda x: jnp.repeat(jnp.repeat(x[None, ...], train_envs, axis=0)[None, ...], 256, axis=0),
|
239 |
+
obs,
|
240 |
+
)
|
241 |
+
init_x = (obs, jnp.zeros((256, train_envs)))
|
242 |
+
init_hstate = ScannedRNN.initialize_carry(train_envs)
|
243 |
+
network_params = network.init(_rng, init_hstate, init_x)
|
244 |
+
if config["anneal_lr"]:
|
245 |
+
tx = optax.chain(
|
246 |
+
optax.clip_by_global_norm(config["max_grad_norm"]),
|
247 |
+
optax.adam(learning_rate=linear_schedule, eps=1e-5),
|
248 |
+
)
|
249 |
+
else:
|
250 |
+
tx = optax.chain(
|
251 |
+
optax.clip_by_global_norm(config["max_grad_norm"]),
|
252 |
+
optax.adam(config["lr"], eps=1e-5),
|
253 |
+
)
|
254 |
+
train_state = TrainState.create(
|
255 |
+
apply_fn=network.apply,
|
256 |
+
params=network_params,
|
257 |
+
tx=tx,
|
258 |
+
)
|
259 |
+
if config["load_from_checkpoint"] != None:
|
260 |
+
print("LOADING from", config["load_from_checkpoint"], "with only params =", config["load_only_params"])
|
261 |
+
train_state = load_train_state_from_wandb_artifact_path(
|
262 |
+
train_state,
|
263 |
+
config["load_from_checkpoint"],
|
264 |
+
load_only_params=config["load_only_params"],
|
265 |
+
legacy=config["load_legacy_checkpoint"],
|
266 |
+
)
|
267 |
+
|
268 |
+
rng, _rng = jax.random.split(rng)
|
269 |
+
|
270 |
+
# INIT ENV
|
271 |
+
rng, _rng, _rng2 = jax.random.split(rng, 3)
|
272 |
+
rng_reset = jax.random.split(_rng, config["num_train_envs"])
|
273 |
+
|
274 |
+
new_levels = sample_random_levels(_rng2, config["num_train_envs"])
|
275 |
+
obsv, env_state = jax.vmap(env.reset_to_level, in_axes=(0, 0, None))(rng_reset, new_levels, env_params)
|
276 |
+
|
277 |
+
start_state = env_state
|
278 |
+
init_hstate = ScannedRNN.initialize_carry(config["num_train_envs"])
|
279 |
+
|
280 |
+
@jax.jit
|
281 |
+
def log_buffer_learnability(rng, train_state, instances):
|
282 |
+
BATCH_SIZE = config["num_to_save"]
|
283 |
+
BATCH_ACTORS = BATCH_SIZE
|
284 |
+
|
285 |
+
def _batch_step(unused, rng):
|
286 |
+
def _env_step(runner_state, unused):
|
287 |
+
env_state, start_state, last_obs, last_done, hstate, rng = runner_state
|
288 |
+
|
289 |
+
# SELECT ACTION
|
290 |
+
rng, _rng = jax.random.split(rng)
|
291 |
+
obs_batch = last_obs
|
292 |
+
ac_in = (
|
293 |
+
jax.tree.map(lambda x: x[np.newaxis, :], obs_batch),
|
294 |
+
last_done[np.newaxis, :],
|
295 |
+
)
|
296 |
+
hstate, pi, value = network.apply(train_state.params, hstate, ac_in)
|
297 |
+
action = pi.sample(seed=_rng).squeeze()
|
298 |
+
log_prob = pi.log_prob(action)
|
299 |
+
env_act = action
|
300 |
+
|
301 |
+
# STEP ENV
|
302 |
+
rng, _rng = jax.random.split(rng)
|
303 |
+
rng_step = jax.random.split(_rng, config["num_to_save"])
|
304 |
+
obsv, env_state, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0, None))(
|
305 |
+
rng_step, env_state, env_act, env_params
|
306 |
+
)
|
307 |
+
done_batch = done
|
308 |
+
|
309 |
+
transition = Transition(
|
310 |
+
done,
|
311 |
+
last_done,
|
312 |
+
action.squeeze(),
|
313 |
+
value.squeeze(),
|
314 |
+
reward,
|
315 |
+
log_prob.squeeze(),
|
316 |
+
obs_batch,
|
317 |
+
info,
|
318 |
+
)
|
319 |
+
runner_state = (env_state, start_state, obsv, done_batch, hstate, rng)
|
320 |
+
return runner_state, transition
|
321 |
+
|
322 |
+
@partial(jax.vmap, in_axes=(None, 1, 1, 1))
|
323 |
+
@partial(jax.jit, static_argnums=(0,))
|
324 |
+
def _calc_outcomes_by_agent(max_steps: int, dones, returns, info):
|
325 |
+
idxs = jnp.arange(max_steps)
|
326 |
+
|
327 |
+
@partial(jax.vmap, in_axes=(0, 0))
|
328 |
+
def __ep_outcomes(start_idx, end_idx):
|
329 |
+
mask = (idxs > start_idx) & (idxs <= end_idx) & (end_idx != max_steps)
|
330 |
+
r = jnp.sum(returns * mask)
|
331 |
+
goal_r = info["GoalR"] # (returns > 0) * 1.0
|
332 |
+
success = jnp.sum(goal_r * mask)
|
333 |
+
l = end_idx - start_idx
|
334 |
+
return r, success, l
|
335 |
+
|
336 |
+
done_idxs = jnp.argwhere(dones, size=50, fill_value=max_steps).squeeze()
|
337 |
+
mask_done = jnp.where(done_idxs == max_steps, 0, 1)
|
338 |
+
ep_return, success, length = __ep_outcomes(
|
339 |
+
jnp.concatenate([jnp.array([-1]), done_idxs[:-1]]), done_idxs
|
340 |
+
)
|
341 |
+
|
342 |
+
return {
|
343 |
+
"ep_return": ep_return.mean(where=mask_done),
|
344 |
+
"num_episodes": mask_done.sum(),
|
345 |
+
"success_rate": success.mean(where=mask_done),
|
346 |
+
"ep_len": length.mean(where=mask_done),
|
347 |
+
}
|
348 |
+
|
349 |
+
# sample envs
|
350 |
+
rng, _rng, _rng2 = jax.random.split(rng, 3)
|
351 |
+
rng_reset = jax.random.split(_rng, config["num_to_save"])
|
352 |
+
rng_levels = jax.random.split(_rng2, config["num_to_save"])
|
353 |
+
# obsv, env_state = jax.vmap(sample_random_level, in_axes=(0,))(reset_rng)
|
354 |
+
# new_levels = jax.vmap(sample_random_level)(rng_levels)
|
355 |
+
obsv, env_state = jax.vmap(env.reset_to_level, in_axes=(0, 0, None))(rng_reset, instances, env_params)
|
356 |
+
# env_instances = new_levels
|
357 |
+
init_hstate = ScannedRNN.initialize_carry(
|
358 |
+
BATCH_ACTORS,
|
359 |
+
)
|
360 |
+
|
361 |
+
runner_state = (env_state, env_state, obsv, jnp.zeros((BATCH_ACTORS), dtype=bool), init_hstate, rng)
|
362 |
+
runner_state, traj_batch = jax.lax.scan(_env_step, runner_state, None, config["rollout_steps"])
|
363 |
+
done_by_env = traj_batch.done.reshape((-1, config["num_to_save"]))
|
364 |
+
reward_by_env = traj_batch.reward.reshape((-1, config["num_to_save"]))
|
365 |
+
# info_by_actor = jax.tree.map(lambda x: x.swapaxes(2, 1).reshape((-1, BATCH_ACTORS)), traj_batch.info)
|
366 |
+
o = _calc_outcomes_by_agent(config["rollout_steps"], traj_batch.done, traj_batch.reward, traj_batch.info)
|
367 |
+
success_by_env = o["success_rate"].reshape((1, config["num_to_save"]))
|
368 |
+
learnability_by_env = (success_by_env * (1 - success_by_env)).sum(axis=0)
|
369 |
+
return None, (learnability_by_env, success_by_env.sum(axis=0))
|
370 |
+
|
371 |
+
rngs = jax.random.split(rng, 1)
|
372 |
+
_, (learnability, success_by_env) = jax.lax.scan(_batch_step, None, rngs, 1)
|
373 |
+
return learnability[0], success_by_env[0]
|
374 |
+
|
375 |
+
num_eval_levels = len(config["eval_levels"])
|
376 |
+
all_eval_levels = get_eval_levels(config["eval_levels"], eval_env.static_env_params)
|
377 |
+
|
378 |
+
eval_group_indices = get_eval_level_groups(config["eval_levels"])
|
379 |
+
print("group indices", eval_group_indices)
|
380 |
+
|
381 |
+
@jax.jit
|
382 |
+
def get_learnability_set(rng, network_params):
|
383 |
+
|
384 |
+
BATCH_ACTORS = config["batch_size"]
|
385 |
+
|
386 |
+
def _batch_step(unused, rng):
|
387 |
+
def _env_step(runner_state, unused):
|
388 |
+
env_state, start_state, last_obs, last_done, hstate, rng = runner_state
|
389 |
+
|
390 |
+
# SELECT ACTION
|
391 |
+
rng, _rng = jax.random.split(rng)
|
392 |
+
obs_batch = last_obs
|
393 |
+
ac_in = (
|
394 |
+
jax.tree.map(lambda x: x[np.newaxis, :], obs_batch),
|
395 |
+
last_done[np.newaxis, :],
|
396 |
+
)
|
397 |
+
hstate, pi, value = network.apply(network_params, hstate, ac_in)
|
398 |
+
action = pi.sample(seed=_rng).squeeze()
|
399 |
+
log_prob = pi.log_prob(action)
|
400 |
+
env_act = action
|
401 |
+
|
402 |
+
# STEP ENV
|
403 |
+
rng, _rng = jax.random.split(rng)
|
404 |
+
rng_step = jax.random.split(_rng, config["batch_size"])
|
405 |
+
obsv, env_state, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0, None))(
|
406 |
+
rng_step, env_state, env_act, env_params
|
407 |
+
)
|
408 |
+
done_batch = done
|
409 |
+
|
410 |
+
transition = Transition(
|
411 |
+
done,
|
412 |
+
last_done,
|
413 |
+
action.squeeze(),
|
414 |
+
value.squeeze(),
|
415 |
+
reward,
|
416 |
+
log_prob.squeeze(),
|
417 |
+
obs_batch,
|
418 |
+
info,
|
419 |
+
)
|
420 |
+
runner_state = (env_state, start_state, obsv, done_batch, hstate, rng)
|
421 |
+
return runner_state, transition
|
422 |
+
|
423 |
+
@partial(jax.vmap, in_axes=(None, 1, 1, 1))
|
424 |
+
@partial(jax.jit, static_argnums=(0,))
|
425 |
+
def _calc_outcomes_by_agent(max_steps: int, dones, returns, info):
|
426 |
+
idxs = jnp.arange(max_steps)
|
427 |
+
|
428 |
+
@partial(jax.vmap, in_axes=(0, 0))
|
429 |
+
def __ep_outcomes(start_idx, end_idx):
|
430 |
+
mask = (idxs > start_idx) & (idxs <= end_idx) & (end_idx != max_steps)
|
431 |
+
r = jnp.sum(returns * mask)
|
432 |
+
goal_r = info["GoalR"] # (returns > 0) * 1.0
|
433 |
+
success = jnp.sum(goal_r * mask)
|
434 |
+
l = end_idx - start_idx
|
435 |
+
return r, success, l
|
436 |
+
|
437 |
+
done_idxs = jnp.argwhere(dones, size=50, fill_value=max_steps).squeeze()
|
438 |
+
mask_done = jnp.where(done_idxs == max_steps, 0, 1)
|
439 |
+
ep_return, success, length = __ep_outcomes(
|
440 |
+
jnp.concatenate([jnp.array([-1]), done_idxs[:-1]]), done_idxs
|
441 |
+
)
|
442 |
+
|
443 |
+
return {
|
444 |
+
"ep_return": ep_return.mean(where=mask_done),
|
445 |
+
"num_episodes": mask_done.sum(),
|
446 |
+
"success_rate": success.mean(where=mask_done),
|
447 |
+
"ep_len": length.mean(where=mask_done),
|
448 |
+
}
|
449 |
+
|
450 |
+
# sample envs
|
451 |
+
rng, _rng, _rng2 = jax.random.split(rng, 3)
|
452 |
+
rng_reset = jax.random.split(_rng, config["batch_size"])
|
453 |
+
new_levels = sample_random_levels(_rng2, config["batch_size"])
|
454 |
+
obsv, env_state = jax.vmap(env.reset_to_level, in_axes=(0, 0, None))(rng_reset, new_levels, env_params)
|
455 |
+
env_instances = new_levels
|
456 |
+
init_hstate = ScannedRNN.initialize_carry(
|
457 |
+
BATCH_ACTORS,
|
458 |
+
)
|
459 |
+
|
460 |
+
runner_state = (env_state, env_state, obsv, jnp.zeros((BATCH_ACTORS), dtype=bool), init_hstate, rng)
|
461 |
+
runner_state, traj_batch = jax.lax.scan(_env_step, runner_state, None, config["rollout_steps"])
|
462 |
+
done_by_env = traj_batch.done.reshape((-1, config["batch_size"]))
|
463 |
+
reward_by_env = traj_batch.reward.reshape((-1, config["batch_size"]))
|
464 |
+
# info_by_actor = jax.tree.map(lambda x: x.swapaxes(2, 1).reshape((-1, BATCH_ACTORS)), traj_batch.info)
|
465 |
+
o = _calc_outcomes_by_agent(config["rollout_steps"], traj_batch.done, traj_batch.reward, traj_batch.info)
|
466 |
+
success_by_env = o["success_rate"].reshape((1, config["batch_size"]))
|
467 |
+
learnability_by_env = (success_by_env * (1 - success_by_env)).sum(axis=0)
|
468 |
+
return None, (learnability_by_env, success_by_env.sum(axis=0), env_instances)
|
469 |
+
|
470 |
+
if config["sampled_envs_ratio"] == 0.0:
|
471 |
+
print("Not doing any rollouts because sampled_envs_ratio is 0.0")
|
472 |
+
# Here we have zero envs, so we can literally just sample random ones because there is no point.
|
473 |
+
top_instances = sample_random_levels(_rng, config["num_to_save"])
|
474 |
+
top_success = top_learn = learnability = success_rates = jnp.zeros(config["num_to_save"])
|
475 |
+
else:
|
476 |
+
rngs = jax.random.split(rng, config["num_batches"])
|
477 |
+
_, (learnability, success_rates, env_instances) = jax.lax.scan(
|
478 |
+
_batch_step, None, rngs, config["num_batches"]
|
479 |
+
)
|
480 |
+
|
481 |
+
flat_env_instances = jax.tree.map(lambda x: x.reshape((-1,) + x.shape[2:]), env_instances)
|
482 |
+
learnability = learnability.flatten() + success_rates.flatten() * 0.001
|
483 |
+
top_1000 = jnp.argsort(learnability)[-config["num_to_save"] :]
|
484 |
+
|
485 |
+
top_1000_instances = jax.tree.map(lambda x: x.at[top_1000].get(), flat_env_instances)
|
486 |
+
top_learn, top_instances = learnability.at[top_1000].get(), top_1000_instances
|
487 |
+
top_success = success_rates.at[top_1000].get()
|
488 |
+
|
489 |
+
if config["put_eval_levels_in_buffer"]:
|
490 |
+
top_instances = jax.tree.map(
|
491 |
+
lambda all, new: jnp.concatenate([all[:-num_eval_levels], new], axis=0),
|
492 |
+
top_instances,
|
493 |
+
all_eval_levels.env_state,
|
494 |
+
)
|
495 |
+
|
496 |
+
log = {
|
497 |
+
"learnability/learnability_sampled_mean": learnability.mean(),
|
498 |
+
"learnability/learnability_sampled_median": jnp.median(learnability),
|
499 |
+
"learnability/learnability_sampled_min": learnability.min(),
|
500 |
+
"learnability/learnability_sampled_max": learnability.max(),
|
501 |
+
"learnability/learnability_selected_mean": top_learn.mean(),
|
502 |
+
"learnability/learnability_selected_median": jnp.median(top_learn),
|
503 |
+
"learnability/learnability_selected_min": top_learn.min(),
|
504 |
+
"learnability/learnability_selected_max": top_learn.max(),
|
505 |
+
"learnability/solve_rate_sampled_mean": top_success.mean(),
|
506 |
+
"learnability/solve_rate_sampled_median": jnp.median(top_success),
|
507 |
+
"learnability/solve_rate_sampled_min": top_success.min(),
|
508 |
+
"learnability/solve_rate_sampled_max": top_success.max(),
|
509 |
+
"learnability/solve_rate_selected_mean": success_rates.mean(),
|
510 |
+
"learnability/solve_rate_selected_median": jnp.median(success_rates),
|
511 |
+
"learnability/solve_rate_selected_min": success_rates.min(),
|
512 |
+
"learnability/solve_rate_selected_max": success_rates.max(),
|
513 |
+
}
|
514 |
+
|
515 |
+
return top_learn, top_instances, log
|
516 |
+
|
517 |
+
def eval(rng: chex.PRNGKey, train_state: TrainState, keep_states=True):
|
518 |
+
"""
|
519 |
+
This evaluates the current policy on the set of evaluation levels specified by config["eval_levels"].
|
520 |
+
It returns (states, cum_rewards, episode_lengths), with shapes (num_steps, num_eval_levels, ...), (num_eval_levels,), (num_eval_levels,)
|
521 |
+
"""
|
522 |
+
num_levels = len(config["eval_levels"])
|
523 |
+
# eval_levels = get_eval_levels(config["eval_levels"], eval_env.static_env_params)
|
524 |
+
return general_eval(
|
525 |
+
rng,
|
526 |
+
eval_env,
|
527 |
+
env_params,
|
528 |
+
train_state,
|
529 |
+
all_eval_levels,
|
530 |
+
env_params.max_timesteps,
|
531 |
+
num_levels,
|
532 |
+
keep_states=keep_states,
|
533 |
+
return_trajectories=True,
|
534 |
+
)
|
535 |
+
|
536 |
+
def eval_on_dr_levels(rng: chex.PRNGKey, train_state: TrainState, keep_states=False):
|
537 |
+
return general_eval(
|
538 |
+
rng,
|
539 |
+
env,
|
540 |
+
env_params,
|
541 |
+
train_state,
|
542 |
+
DR_EVAL_LEVELS,
|
543 |
+
env_params.max_timesteps,
|
544 |
+
NUM_EVAL_DR_LEVELS,
|
545 |
+
keep_states=keep_states,
|
546 |
+
)
|
547 |
+
|
548 |
+
def eval_on_top_learnable_levels(rng: chex.PRNGKey, train_state: TrainState, levels, keep_states=True):
|
549 |
+
N = 5
|
550 |
+
return general_eval(
|
551 |
+
rng,
|
552 |
+
env,
|
553 |
+
env_params,
|
554 |
+
train_state,
|
555 |
+
jax.tree.map(lambda x: x[:N], levels),
|
556 |
+
env_params.max_timesteps,
|
557 |
+
N,
|
558 |
+
keep_states=keep_states,
|
559 |
+
)
|
560 |
+
|
561 |
+
# TRAIN LOOP
|
562 |
+
def train_step(runner_state_instances, unused):
|
563 |
+
# COLLECT TRAJECTORIES
|
564 |
+
runner_state, instances = runner_state_instances
|
565 |
+
num_env_instances = instances.polygon.position.shape[0]
|
566 |
+
|
567 |
+
def _env_step(runner_state, unused):
|
568 |
+
train_state, env_state, start_state, last_obs, last_done, hstate, update_steps, rng = runner_state
|
569 |
+
|
570 |
+
# SELECT ACTION
|
571 |
+
rng, _rng = jax.random.split(rng)
|
572 |
+
obs_batch = last_obs
|
573 |
+
ac_in = (
|
574 |
+
jax.tree.map(lambda x: x[np.newaxis, :], obs_batch),
|
575 |
+
last_done[np.newaxis, :],
|
576 |
+
)
|
577 |
+
hstate, pi, value = network.apply(train_state.params, hstate, ac_in)
|
578 |
+
action = pi.sample(seed=_rng).squeeze()
|
579 |
+
log_prob = pi.log_prob(action)
|
580 |
+
env_act = action
|
581 |
+
|
582 |
+
# STEP ENV
|
583 |
+
rng, _rng = jax.random.split(rng)
|
584 |
+
rng_step = jax.random.split(_rng, config["num_train_envs"])
|
585 |
+
obsv, env_state, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0, None))(
|
586 |
+
rng_step, env_state, env_act, env_params
|
587 |
+
)
|
588 |
+
done_batch = done
|
589 |
+
transition = Transition(
|
590 |
+
done,
|
591 |
+
last_done,
|
592 |
+
action.squeeze(),
|
593 |
+
value.squeeze(),
|
594 |
+
reward,
|
595 |
+
log_prob.squeeze(),
|
596 |
+
obs_batch,
|
597 |
+
info,
|
598 |
+
)
|
599 |
+
runner_state = (train_state, env_state, start_state, obsv, done_batch, hstate, update_steps, rng)
|
600 |
+
return runner_state, (transition)
|
601 |
+
|
602 |
+
initial_hstate = runner_state[-3]
|
603 |
+
runner_state, traj_batch = jax.lax.scan(_env_step, runner_state, None, config["num_steps"])
|
604 |
+
|
605 |
+
# CALCULATE ADVANTAGE
|
606 |
+
train_state, env_state, start_state, last_obs, last_done, hstate, update_steps, rng = runner_state
|
607 |
+
last_obs_batch = last_obs # batchify(last_obs, env.agents, config["num_train_envs"])
|
608 |
+
ac_in = (
|
609 |
+
jax.tree.map(lambda x: x[np.newaxis, :], last_obs_batch),
|
610 |
+
last_done[np.newaxis, :],
|
611 |
+
)
|
612 |
+
_, _, last_val = network.apply(train_state.params, hstate, ac_in)
|
613 |
+
last_val = last_val.squeeze()
|
614 |
+
|
615 |
+
def _calculate_gae(traj_batch, last_val):
|
616 |
+
def _get_advantages(gae_and_next_value, transition: Transition):
|
617 |
+
gae, next_value = gae_and_next_value
|
618 |
+
done, value, reward = (
|
619 |
+
transition.global_done,
|
620 |
+
transition.value,
|
621 |
+
transition.reward,
|
622 |
+
)
|
623 |
+
delta = reward + config["gamma"] * next_value * (1 - done) - value
|
624 |
+
gae = delta + config["gamma"] * config["gae_lambda"] * (1 - done) * gae
|
625 |
+
return (gae, value), gae
|
626 |
+
|
627 |
+
_, advantages = jax.lax.scan(
|
628 |
+
_get_advantages,
|
629 |
+
(jnp.zeros_like(last_val), last_val),
|
630 |
+
traj_batch,
|
631 |
+
reverse=True,
|
632 |
+
unroll=16,
|
633 |
+
)
|
634 |
+
return advantages, advantages + traj_batch.value
|
635 |
+
|
636 |
+
advantages, targets = _calculate_gae(traj_batch, last_val)
|
637 |
+
|
638 |
+
# UPDATE NETWORK
|
639 |
+
def _update_epoch(update_state, unused):
|
640 |
+
def _update_minbatch(train_state, batch_info):
|
641 |
+
init_hstate, traj_batch, advantages, targets = batch_info
|
642 |
+
|
643 |
+
def _loss_fn_masked(params, init_hstate, traj_batch, gae, targets):
|
644 |
+
|
645 |
+
# RERUN NETWORK
|
646 |
+
_, pi, value = network.apply(
|
647 |
+
params,
|
648 |
+
jax.tree.map(lambda x: x.transpose(), init_hstate),
|
649 |
+
(traj_batch.obs, traj_batch.done),
|
650 |
+
)
|
651 |
+
log_prob = pi.log_prob(traj_batch.action)
|
652 |
+
|
653 |
+
# CALCULATE VALUE LOSS
|
654 |
+
value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip(
|
655 |
+
-config["clip_eps"], config["clip_eps"]
|
656 |
+
)
|
657 |
+
value_losses = jnp.square(value - targets)
|
658 |
+
value_losses_clipped = jnp.square(value_pred_clipped - targets)
|
659 |
+
value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped)
|
660 |
+
critic_loss = config["vf_coef"] * value_loss.mean()
|
661 |
+
|
662 |
+
# CALCULATE ACTOR LOSS
|
663 |
+
logratio = log_prob - traj_batch.log_prob
|
664 |
+
ratio = jnp.exp(logratio)
|
665 |
+
# if env.do_sep_reward: gae = gae.sum(axis=-1)
|
666 |
+
gae = (gae - gae.mean()) / (gae.std() + 1e-8)
|
667 |
+
loss_actor1 = ratio * gae
|
668 |
+
loss_actor2 = (
|
669 |
+
jnp.clip(
|
670 |
+
ratio,
|
671 |
+
1.0 - config["clip_eps"],
|
672 |
+
1.0 + config["clip_eps"],
|
673 |
+
)
|
674 |
+
* gae
|
675 |
+
)
|
676 |
+
loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
|
677 |
+
loss_actor = loss_actor.mean()
|
678 |
+
entropy = pi.entropy().mean()
|
679 |
+
|
680 |
+
approx_kl = jax.lax.stop_gradient(((ratio - 1) - logratio).mean())
|
681 |
+
clipfrac = jax.lax.stop_gradient((jnp.abs(ratio - 1) > config["clip_eps"]).mean())
|
682 |
+
|
683 |
+
total_loss = loss_actor + critic_loss - config["ent_coef"] * entropy
|
684 |
+
return total_loss, (value_loss, loss_actor, entropy, ratio, approx_kl, clipfrac)
|
685 |
+
|
686 |
+
grad_fn = jax.value_and_grad(_loss_fn_masked, has_aux=True)
|
687 |
+
total_loss, grads = grad_fn(train_state.params, init_hstate, traj_batch, advantages, targets)
|
688 |
+
train_state = train_state.apply_gradients(grads=grads)
|
689 |
+
return train_state, total_loss
|
690 |
+
|
691 |
+
(
|
692 |
+
train_state,
|
693 |
+
init_hstate,
|
694 |
+
traj_batch,
|
695 |
+
advantages,
|
696 |
+
targets,
|
697 |
+
rng,
|
698 |
+
) = update_state
|
699 |
+
rng, _rng = jax.random.split(rng)
|
700 |
+
|
701 |
+
init_hstate = jax.tree.map(lambda x: jnp.reshape(x, (256, config["num_train_envs"])), init_hstate)
|
702 |
+
batch = (
|
703 |
+
init_hstate,
|
704 |
+
traj_batch,
|
705 |
+
advantages.squeeze(),
|
706 |
+
targets.squeeze(),
|
707 |
+
)
|
708 |
+
permutation = jax.random.permutation(_rng, config["num_train_envs"])
|
709 |
+
|
710 |
+
shuffled_batch = jax.tree_util.tree_map(lambda x: jnp.take(x, permutation, axis=1), batch)
|
711 |
+
|
712 |
+
minibatches = jax.tree_util.tree_map(
|
713 |
+
lambda x: jnp.swapaxes(
|
714 |
+
jnp.reshape(
|
715 |
+
x,
|
716 |
+
[x.shape[0], config["num_minibatches"], -1] + list(x.shape[2:]),
|
717 |
+
),
|
718 |
+
1,
|
719 |
+
0,
|
720 |
+
),
|
721 |
+
shuffled_batch,
|
722 |
+
)
|
723 |
+
|
724 |
+
train_state, total_loss = jax.lax.scan(_update_minbatch, train_state, minibatches)
|
725 |
+
# total_loss = jax.tree.map(lambda x: x.mean(), total_loss)
|
726 |
+
update_state = (
|
727 |
+
train_state,
|
728 |
+
init_hstate,
|
729 |
+
traj_batch,
|
730 |
+
advantages,
|
731 |
+
targets,
|
732 |
+
rng,
|
733 |
+
)
|
734 |
+
return update_state, total_loss
|
735 |
+
|
736 |
+
# init_hstate = initial_hstate[None, :].squeeze().transpose()
|
737 |
+
init_hstate = jax.tree.map(lambda x: x[None, :].squeeze().transpose(), initial_hstate)
|
738 |
+
update_state = (
|
739 |
+
train_state,
|
740 |
+
init_hstate,
|
741 |
+
traj_batch,
|
742 |
+
advantages,
|
743 |
+
targets,
|
744 |
+
rng,
|
745 |
+
)
|
746 |
+
update_state, loss_info = jax.lax.scan(_update_epoch, update_state, None, config["update_epochs"])
|
747 |
+
train_state = update_state[0]
|
748 |
+
metric = traj_batch.info
|
749 |
+
metric = jax.tree.map(
|
750 |
+
lambda x: x.reshape((config["num_steps"], config["num_train_envs"])), # , env.num_agents
|
751 |
+
traj_batch.info,
|
752 |
+
)
|
753 |
+
rng = update_state[-1]
|
754 |
+
|
755 |
+
def callback(metric):
|
756 |
+
dones = metric["dones"]
|
757 |
+
wandb.log(
|
758 |
+
{
|
759 |
+
"episode_return": (metric["returned_episode_returns"] * dones).sum() / jnp.maximum(1, dones.sum()),
|
760 |
+
"episode_solved": (metric["returned_episode_solved"] * dones).sum() / jnp.maximum(1, dones.sum()),
|
761 |
+
"episode_length": (metric["returned_episode_lengths"] * dones).sum() / jnp.maximum(1, dones.sum()),
|
762 |
+
"timing/num_env_steps": int(
|
763 |
+
int(metric["update_steps"]) * int(config["num_train_envs"]) * int(config["num_steps"])
|
764 |
+
),
|
765 |
+
"timing/num_updates": metric["update_steps"],
|
766 |
+
**metric["loss_info"],
|
767 |
+
}
|
768 |
+
)
|
769 |
+
|
770 |
+
loss_info = jax.tree.map(lambda x: x.mean(), loss_info)
|
771 |
+
metric["loss_info"] = {
|
772 |
+
"loss/total_loss": loss_info[0],
|
773 |
+
"loss/value_loss": loss_info[1][0],
|
774 |
+
"loss/policy_loss": loss_info[1][1],
|
775 |
+
"loss/entropy_loss": loss_info[1][2],
|
776 |
+
}
|
777 |
+
metric["dones"] = traj_batch.done
|
778 |
+
metric["update_steps"] = update_steps
|
779 |
+
jax.experimental.io_callback(callback, None, metric)
|
780 |
+
|
781 |
+
# SAMPLE NEW ENVS
|
782 |
+
rng, _rng, _rng2 = jax.random.split(rng, 3)
|
783 |
+
rng_reset = jax.random.split(_rng, config["num_envs_to_generate"])
|
784 |
+
|
785 |
+
new_levels = sample_random_levels(_rng2, config["num_envs_to_generate"])
|
786 |
+
obsv_gen, env_state_gen = jax.vmap(env.reset_to_level, in_axes=(0, 0, None))(rng_reset, new_levels, env_params)
|
787 |
+
|
788 |
+
rng, _rng, _rng2 = jax.random.split(rng, 3)
|
789 |
+
sampled_env_instances_idxs = jax.random.randint(_rng, (config["num_envs_from_sampled"],), 0, num_env_instances)
|
790 |
+
sampled_env_instances = jax.tree.map(lambda x: x.at[sampled_env_instances_idxs].get(), instances)
|
791 |
+
myrng = jax.random.split(_rng2, config["num_envs_from_sampled"])
|
792 |
+
obsv_sampled, env_state_sampled = jax.vmap(env.reset_to_level, in_axes=(0, 0))(myrng, sampled_env_instances)
|
793 |
+
|
794 |
+
obsv = jax.tree.map(lambda x, y: jnp.concatenate([x, y], axis=0), obsv_gen, obsv_sampled)
|
795 |
+
env_state = jax.tree.map(lambda x, y: jnp.concatenate([x, y], axis=0), env_state_gen, env_state_sampled)
|
796 |
+
|
797 |
+
start_state = env_state
|
798 |
+
hstate = ScannedRNN.initialize_carry(config["num_train_envs"])
|
799 |
+
|
800 |
+
update_steps = update_steps + 1
|
801 |
+
runner_state = (
|
802 |
+
train_state,
|
803 |
+
env_state,
|
804 |
+
start_state,
|
805 |
+
obsv,
|
806 |
+
jnp.zeros((config["num_train_envs"]), dtype=bool),
|
807 |
+
hstate,
|
808 |
+
update_steps,
|
809 |
+
rng,
|
810 |
+
)
|
811 |
+
return (runner_state, instances), metric
|
812 |
+
|
813 |
+
def log_buffer(learnability, levels, epoch):
|
814 |
+
num_samples = levels.polygon.position.shape[0]
|
815 |
+
states = levels
|
816 |
+
rows = 2
|
817 |
+
fig, axes = plt.subplots(rows, int(num_samples / rows), figsize=(20, 10))
|
818 |
+
axes = axes.flatten()
|
819 |
+
all_imgs = jax.vmap(render_fn)(states)
|
820 |
+
for i, ax in enumerate(axes):
|
821 |
+
# ax.imshow(train_state.plr_buffer.get_sample(i))
|
822 |
+
score = learnability[i]
|
823 |
+
ax.imshow(all_imgs[i] / 255.0)
|
824 |
+
ax.set_xticks([])
|
825 |
+
ax.set_yticks([])
|
826 |
+
ax.set_title(f"learnability: {score:.3f}")
|
827 |
+
ax.set_aspect("equal", "box")
|
828 |
+
|
829 |
+
plt.tight_layout()
|
830 |
+
fig.canvas.draw()
|
831 |
+
im = Image.frombytes("RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
|
832 |
+
plt.close()
|
833 |
+
return {"maps": wandb.Image(im)}
|
834 |
+
|
835 |
+
@jax.jit
|
836 |
+
def train_and_eval_step(runner_state, eval_rng):
|
837 |
+
|
838 |
+
learnability_rng, eval_singleton_rng, eval_sampled_rng, _rng = jax.random.split(eval_rng, 4)
|
839 |
+
# TRAIN
|
840 |
+
learnabilty_scores, instances, test_metrics = get_learnability_set(learnability_rng, runner_state[0].params)
|
841 |
+
|
842 |
+
if config["log_learnability_before_after"]:
|
843 |
+
learn_scores_before, success_score_before = log_buffer_learnability(
|
844 |
+
learnability_rng, runner_state[0], instances
|
845 |
+
)
|
846 |
+
|
847 |
+
print("instance size", sum(x.size for x in jax.tree_util.tree_leaves(instances)))
|
848 |
+
|
849 |
+
runner_state_instances = (runner_state, instances)
|
850 |
+
runner_state_instances, metrics = jax.lax.scan(train_step, runner_state_instances, None, config["eval_freq"])
|
851 |
+
|
852 |
+
if config["log_learnability_before_after"]:
|
853 |
+
learn_scores_after, success_score_after = log_buffer_learnability(
|
854 |
+
learnability_rng, runner_state_instances[0][0], instances
|
855 |
+
)
|
856 |
+
|
857 |
+
# EVAL
|
858 |
+
rng, rng_eval = jax.random.split(eval_singleton_rng)
|
859 |
+
(states, cum_rewards, _, episode_lengths, eval_infos), (eval_dones, eval_rewards) = jax.vmap(eval, (0, None))(
|
860 |
+
jax.random.split(rng_eval, config["eval_num_attempts"]), runner_state_instances[0][0]
|
861 |
+
)
|
862 |
+
all_eval_eplens = episode_lengths
|
863 |
+
|
864 |
+
# Collect Metrics
|
865 |
+
eval_returns = cum_rewards.mean(axis=0) # (num_eval_levels,)
|
866 |
+
eval_solves = (eval_infos["returned_episode_solved"] * eval_dones).sum(axis=1) / jnp.maximum(
|
867 |
+
1, eval_dones.sum(axis=1)
|
868 |
+
)
|
869 |
+
eval_solves = eval_solves.mean(axis=0)
|
870 |
+
# just grab the first run
|
871 |
+
states, episode_lengths = jax.tree_util.tree_map(
|
872 |
+
lambda x: x[0], (states, episode_lengths)
|
873 |
+
) # (num_steps, num_eval_levels, ...), (num_eval_levels,)
|
874 |
+
# And one attempt
|
875 |
+
states = jax.tree_util.tree_map(lambda x: x[:, :], states)
|
876 |
+
episode_lengths = episode_lengths[:]
|
877 |
+
images = jax.vmap(jax.vmap(render_fn_eval))(
|
878 |
+
states.env_state.env_state.env_state
|
879 |
+
) # (num_steps, num_eval_levels, ...)
|
880 |
+
frames = images.transpose(
|
881 |
+
0, 1, 4, 2, 3
|
882 |
+
) # WandB expects color channel before image dimensions when dealing with animations for some reason
|
883 |
+
|
884 |
+
test_metrics["update_count"] = runner_state[-2]
|
885 |
+
test_metrics["eval_returns"] = eval_returns
|
886 |
+
test_metrics["eval_ep_lengths"] = episode_lengths
|
887 |
+
test_metrics["eval_animation"] = (frames, episode_lengths)
|
888 |
+
|
889 |
+
# Eval on sampled
|
890 |
+
dr_states, dr_cum_rewards, _, dr_episode_lengths, dr_infos = jax.vmap(eval_on_dr_levels, (0, None))(
|
891 |
+
jax.random.split(rng_eval, config["eval_num_attempts"]), runner_state_instances[0][0]
|
892 |
+
)
|
893 |
+
|
894 |
+
eval_dr_returns = dr_cum_rewards.mean(axis=0).mean()
|
895 |
+
eval_dr_eplen = dr_episode_lengths.mean(axis=0).mean()
|
896 |
+
|
897 |
+
test_metrics["eval/mean_eval_return_sampled"] = eval_dr_returns
|
898 |
+
my_eval_dones = dr_infos["returned_episode"]
|
899 |
+
eval_dr_solves = (dr_infos["returned_episode_solved"] * my_eval_dones).sum(axis=1) / jnp.maximum(
|
900 |
+
1, my_eval_dones.sum(axis=1)
|
901 |
+
)
|
902 |
+
|
903 |
+
test_metrics["eval/mean_eval_solve_rate_sampled"] = eval_dr_solves
|
904 |
+
test_metrics["eval/mean_eval_eplen_sampled"] = eval_dr_eplen
|
905 |
+
|
906 |
+
# Collect Metrics
|
907 |
+
eval_returns = cum_rewards.mean(axis=0) # (num_eval_levels,)
|
908 |
+
|
909 |
+
log_dict = {}
|
910 |
+
|
911 |
+
log_dict["to_remove"] = {
|
912 |
+
"eval_return": eval_returns,
|
913 |
+
"eval_solve_rate": eval_solves,
|
914 |
+
"eval_eplen": all_eval_eplens,
|
915 |
+
}
|
916 |
+
|
917 |
+
for i, name in enumerate(config["eval_levels"]):
|
918 |
+
log_dict[f"eval_avg_return/{name}"] = eval_returns[i]
|
919 |
+
log_dict[f"eval_avg_solve_rate/{name}"] = eval_solves[i]
|
920 |
+
|
921 |
+
log_dict.update({"eval/mean_eval_return": eval_returns.mean()})
|
922 |
+
log_dict.update({"eval/mean_eval_solve_rate": eval_solves.mean()})
|
923 |
+
log_dict.update({"eval/mean_eval_eplen": all_eval_eplens.mean()})
|
924 |
+
|
925 |
+
test_metrics.update(log_dict)
|
926 |
+
|
927 |
+
runner_state, _ = runner_state_instances
|
928 |
+
test_metrics["update_count"] = runner_state[-2]
|
929 |
+
|
930 |
+
top_instances = jax.tree.map(lambda x: x.at[-5:].get(), instances)
|
931 |
+
|
932 |
+
# Eval on top learnable levels
|
933 |
+
tl_states, tl_cum_rewards, _, tl_episode_lengths, tl_infos = jax.vmap(
|
934 |
+
eval_on_top_learnable_levels, (0, None, None)
|
935 |
+
)(jax.random.split(rng_eval, config["eval_num_attempts"]), runner_state_instances[0][0], top_instances)
|
936 |
+
|
937 |
+
# just grab the first run
|
938 |
+
states, episode_lengths = jax.tree_util.tree_map(
|
939 |
+
lambda x: x[0], (tl_states, tl_episode_lengths)
|
940 |
+
) # (num_steps, num_eval_levels, ...), (num_eval_levels,)
|
941 |
+
# And one attempt
|
942 |
+
states = jax.tree_util.tree_map(lambda x: x[:, :], states)
|
943 |
+
episode_lengths = episode_lengths[:]
|
944 |
+
images = jax.vmap(jax.vmap(render_fn))(
|
945 |
+
states.env_state.env_state.env_state
|
946 |
+
) # (num_steps, num_eval_levels, ...)
|
947 |
+
frames = images.transpose(
|
948 |
+
0, 1, 4, 2, 3
|
949 |
+
) # WandB expects color channel before image dimensions when dealing with animations for some reason
|
950 |
+
|
951 |
+
test_metrics["top_learnable_animation"] = (frames, episode_lengths, tl_cum_rewards)
|
952 |
+
|
953 |
+
if config["log_learnability_before_after"]:
|
954 |
+
|
955 |
+
def single(x, name):
|
956 |
+
return {
|
957 |
+
f"{name}_mean": x.mean(),
|
958 |
+
f"{name}_std": x.std(),
|
959 |
+
f"{name}_min": x.min(),
|
960 |
+
f"{name}_max": x.max(),
|
961 |
+
f"{name}_median": jnp.median(x),
|
962 |
+
}
|
963 |
+
|
964 |
+
test_metrics["learnability_log_v2/"] = {
|
965 |
+
**single(learn_scores_before, "learnability_before"),
|
966 |
+
**single(learn_scores_after, "learnability_after"),
|
967 |
+
**single(success_score_before, "success_score_before"),
|
968 |
+
**single(success_score_after, "success_score_after"),
|
969 |
+
}
|
970 |
+
|
971 |
+
return runner_state, (learnabilty_scores.at[-20:].get(), top_instances), test_metrics
|
972 |
+
|
973 |
+
rng, _rng = jax.random.split(rng)
|
974 |
+
runner_state = (
|
975 |
+
train_state,
|
976 |
+
env_state,
|
977 |
+
start_state,
|
978 |
+
obsv,
|
979 |
+
jnp.zeros((config["num_train_envs"]), dtype=bool),
|
980 |
+
init_hstate,
|
981 |
+
0,
|
982 |
+
_rng,
|
983 |
+
)
|
984 |
+
|
985 |
+
def log_eval(stats):
|
986 |
+
log_dict = {}
|
987 |
+
|
988 |
+
to_remove = stats["to_remove"]
|
989 |
+
del stats["to_remove"]
|
990 |
+
|
991 |
+
def _aggregate_per_size(values, name):
|
992 |
+
to_return = {}
|
993 |
+
for group_name, indices in eval_group_indices.items():
|
994 |
+
to_return[f"{name}_{group_name}"] = values[indices].mean()
|
995 |
+
return to_return
|
996 |
+
|
997 |
+
env_steps = stats["update_count"] * config["num_train_envs"] * config["num_steps"]
|
998 |
+
env_steps_delta = config["eval_freq"] * config["num_train_envs"] * config["num_steps"]
|
999 |
+
time_now = time.time()
|
1000 |
+
log_dict = {
|
1001 |
+
"timing/num_updates": stats["update_count"],
|
1002 |
+
"timing/num_env_steps": env_steps,
|
1003 |
+
"timing/sps": env_steps_delta / stats["time_delta"],
|
1004 |
+
"timing/sps_agg": env_steps / (time_now - time_start),
|
1005 |
+
}
|
1006 |
+
log_dict.update(_aggregate_per_size(to_remove["eval_return"], "eval_aggregate/return"))
|
1007 |
+
log_dict.update(_aggregate_per_size(to_remove["eval_solve_rate"], "eval_aggregate/solve_rate"))
|
1008 |
+
|
1009 |
+
for i in range((len(config["eval_levels"]))):
|
1010 |
+
frames, episode_length = stats["eval_animation"][0][:, i], stats["eval_animation"][1][i]
|
1011 |
+
frames = np.array(frames[:episode_length])
|
1012 |
+
log_dict.update(
|
1013 |
+
{
|
1014 |
+
f"media/eval_video_{config['eval_levels'][i]}": wandb.Video(
|
1015 |
+
frames.astype(np.uint8), fps=15, caption=f"(len {episode_length})"
|
1016 |
+
)
|
1017 |
+
}
|
1018 |
+
)
|
1019 |
+
|
1020 |
+
for j in range(5):
|
1021 |
+
frames, episode_length, cum_rewards = (
|
1022 |
+
stats["top_learnable_animation"][0][:, j],
|
1023 |
+
stats["top_learnable_animation"][1][j],
|
1024 |
+
stats["top_learnable_animation"][2][:, j],
|
1025 |
+
) # num attempts
|
1026 |
+
rr = "|".join([f"{r:<.2f}" for r in cum_rewards])
|
1027 |
+
frames = np.array(frames[:episode_length])
|
1028 |
+
log_dict.update(
|
1029 |
+
{
|
1030 |
+
f"media/tl_animation_{j}": wandb.Video(
|
1031 |
+
frames.astype(np.uint8), fps=15, caption=f"(len {episode_length})\n{rr}"
|
1032 |
+
)
|
1033 |
+
}
|
1034 |
+
)
|
1035 |
+
|
1036 |
+
stats.update(log_dict)
|
1037 |
+
wandb.log(stats, step=stats["update_count"])
|
1038 |
+
|
1039 |
+
checkpoint_steps = config["checkpoint_save_freq"]
|
1040 |
+
assert config["num_updates"] % config["eval_freq"] == 0, "num_updates must be divisible by eval_freq"
|
1041 |
+
|
1042 |
+
for eval_step in range(int(config["num_updates"] // config["eval_freq"])):
|
1043 |
+
start_time = time.time()
|
1044 |
+
rng, eval_rng = jax.random.split(rng)
|
1045 |
+
runner_state, instances, metrics = train_and_eval_step(runner_state, eval_rng)
|
1046 |
+
curr_time = time.time()
|
1047 |
+
metrics.update(log_buffer(*instances, metrics["update_count"]))
|
1048 |
+
metrics["time_delta"] = curr_time - start_time
|
1049 |
+
metrics["steps_per_section"] = (config["eval_freq"] * config["num_steps"] * config["num_train_envs"]) / metrics[
|
1050 |
+
"time_delta"
|
1051 |
+
]
|
1052 |
+
log_eval(metrics)
|
1053 |
+
if ((eval_step + 1) * config["eval_freq"]) % checkpoint_steps == 0:
|
1054 |
+
if config["save_path"] is not None:
|
1055 |
+
steps = int(metrics["update_count"]) * int(config["num_train_envs"]) * int(config["num_steps"])
|
1056 |
+
# save_params_to_wandb(runner_state[0].params, steps, config)
|
1057 |
+
save_model_to_wandb(runner_state[0], steps, config)
|
1058 |
+
|
1059 |
+
if config["save_path"] is not None:
|
1060 |
+
# save_params_to_wandb(runner_state[0].params, config["total_timesteps"], config)
|
1061 |
+
save_model_to_wandb(runner_state[0], config["total_timesteps"], config)
|
1062 |
+
|
1063 |
+
|
1064 |
+
if __name__ == "__main__":
|
1065 |
+
# with jax.disable_jit():
|
1066 |
+
# main()
|
1067 |
+
main()
|
Kinetix/images/bb.gif
ADDED
![]() |
Kinetix/images/cartpole.gif
ADDED
![]() |