sgoodfriend
commited on
Commit
•
7c70ebe
1
Parent(s):
950effc
A2C playing PongNoFrameskip-v4 from https://github.com/sgoodfriend/rl-algo-impls/tree/0511de345b17175b7cf1ea706c3e05981f11761c
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- README.md +14 -11
- pyproject.toml +23 -2
- replay.meta.json +1 -1
- rl_algo_impls/a2c/a2c.py +13 -19
- rl_algo_impls/a2c/optimize.py +1 -1
- rl_algo_impls/benchmark_publish.py +2 -2
- rl_algo_impls/compare_runs.py +2 -1
- rl_algo_impls/dqn/policy.py +14 -7
- rl_algo_impls/dqn/q_net.py +6 -6
- rl_algo_impls/huggingface_publish.py +1 -1
- rl_algo_impls/hyperparams/a2c.yml +17 -13
- rl_algo_impls/hyperparams/dqn.yml +1 -1
- rl_algo_impls/hyperparams/ppo.yml +125 -5
- rl_algo_impls/hyperparams/vpg.yml +4 -4
- rl_algo_impls/optimize.py +5 -4
- rl_algo_impls/ppo/ppo.py +248 -227
- rl_algo_impls/runner/config.py +9 -3
- rl_algo_impls/runner/evaluate.py +2 -2
- rl_algo_impls/runner/running_utils.py +33 -18
- rl_algo_impls/runner/train.py +11 -10
- rl_algo_impls/shared/actor/__init__.py +2 -0
- rl_algo_impls/shared/actor/actor.py +42 -0
- rl_algo_impls/shared/actor/categorical.py +64 -0
- rl_algo_impls/shared/actor/gaussian.py +61 -0
- rl_algo_impls/shared/actor/gridnet.py +108 -0
- rl_algo_impls/shared/actor/gridnet_decoder.py +80 -0
- rl_algo_impls/shared/actor/make_actor.py +95 -0
- rl_algo_impls/shared/actor/multi_discrete.py +101 -0
- rl_algo_impls/shared/{policy/actor.py → actor/state_dependent_noise.py} +33 -143
- rl_algo_impls/shared/callbacks/eval_callback.py +26 -9
- rl_algo_impls/shared/encoder/__init__.py +2 -0
- rl_algo_impls/shared/encoder/cnn.py +72 -0
- rl_algo_impls/shared/encoder/encoder.py +73 -0
- rl_algo_impls/shared/encoder/gridnet_encoder.py +64 -0
- rl_algo_impls/shared/encoder/impala_cnn.py +92 -0
- rl_algo_impls/shared/encoder/microrts_cnn.py +45 -0
- rl_algo_impls/shared/encoder/nature_cnn.py +53 -0
- rl_algo_impls/shared/gae.py +29 -2
- rl_algo_impls/shared/module/feature_extractor.py +0 -215
- rl_algo_impls/shared/module/module.py +6 -3
- rl_algo_impls/shared/policy/critic.py +22 -10
- rl_algo_impls/shared/policy/on_policy.py +57 -34
- rl_algo_impls/shared/policy/policy.py +6 -1
- rl_algo_impls/shared/schedule.py +29 -1
- rl_algo_impls/shared/stats.py +24 -6
- rl_algo_impls/shared/vec_env/__init__.py +1 -0
- rl_algo_impls/shared/vec_env/make_env.py +66 -0
- rl_algo_impls/shared/vec_env/microrts.py +94 -0
- rl_algo_impls/shared/vec_env/microrts_compat.py +49 -0
- rl_algo_impls/shared/vec_env/procgen.py +81 -0
README.md
CHANGED
@@ -23,17 +23,17 @@ model-index:
|
|
23 |
|
24 |
This is a trained model of a **A2C** agent playing **PongNoFrameskip-v4** using the [/sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) repo.
|
25 |
|
26 |
-
All models trained at this commit can be found at https://api.wandb.ai/links/sgoodfriend/
|
27 |
|
28 |
## Training Results
|
29 |
|
30 |
-
This model was trained from 3 trainings of **A2C** agents using different initial seeds. These agents were trained by checking out [
|
31 |
|
32 |
| algo | env | seed | reward_mean | reward_std | eval_episodes | best | wandb_url |
|
33 |
|:-------|:-------------------|-------:|--------------:|-------------:|----------------:|:-------|:-----------------------------------------------------------------------------|
|
34 |
-
| a2c | PongNoFrameskip-v4 | 1 | 21 | 0 | 16 | | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/
|
35 |
-
| a2c | PongNoFrameskip-v4 | 2 | 21 | 0 | 16 | | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/
|
36 |
-
| a2c | PongNoFrameskip-v4 | 3 | 21 | 0 | 16 | * | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/
|
37 |
|
38 |
|
39 |
### Prerequisites: Weights & Biases (WandB)
|
@@ -53,10 +53,10 @@ login`.
|
|
53 |
Note: While the model state dictionary and hyperaparameters are saved, the latest
|
54 |
implementation could be sufficiently different to not be able to reproduce similar
|
55 |
results. You might need to checkout the commit the agent was trained on:
|
56 |
-
[
|
57 |
```
|
58 |
# Downloads the model, sets hyperparameters, and runs agent for 3 episodes
|
59 |
-
python enjoy.py --wandb-run-path=sgoodfriend/rl-algo-impls-benchmarks/
|
60 |
```
|
61 |
|
62 |
Setup hasn't been completely worked out yet, so you might be best served by using Google
|
@@ -68,7 +68,7 @@ notebook.
|
|
68 |
|
69 |
## Training
|
70 |
If you want the highest chance to reproduce these results, you'll want to checkout the
|
71 |
-
commit the agent was trained on: [
|
72 |
training is deterministic, different hardware will give different results.
|
73 |
|
74 |
```
|
@@ -83,7 +83,7 @@ notebook.
|
|
83 |
|
84 |
|
85 |
## Benchmarking (with Lambda Labs instance)
|
86 |
-
This and other models from https://api.wandb.ai/links/sgoodfriend/
|
87 |
Labs instance. In a Lambda Labs instance terminal:
|
88 |
```
|
89 |
git clone git@github.com:sgoodfriend/rl-algo-impls.git
|
@@ -105,6 +105,7 @@ can be used. However, this requires a Google Colab Pro+ subscription and running
|
|
105 |
This isn't exactly the format of hyperparams in hyperparams/a2c.yml, but instead the Wandb Run Config. However, it's very
|
106 |
close and has some additional data:
|
107 |
```
|
|
|
108 |
algo: a2c
|
109 |
algo_hyperparams:
|
110 |
ent_coef: 0.01
|
@@ -128,7 +129,9 @@ wandb_entity: null
|
|
128 |
wandb_group: null
|
129 |
wandb_project_name: rl-algo-impls-benchmarks
|
130 |
wandb_tags:
|
131 |
-
-
|
132 |
-
-
|
|
|
|
|
133 |
|
134 |
```
|
|
|
23 |
|
24 |
This is a trained model of a **A2C** agent playing **PongNoFrameskip-v4** using the [/sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) repo.
|
25 |
|
26 |
+
All models trained at this commit can be found at https://api.wandb.ai/links/sgoodfriend/7lx79bf0.
|
27 |
|
28 |
## Training Results
|
29 |
|
30 |
+
This model was trained from 3 trainings of **A2C** agents using different initial seeds. These agents were trained by checking out [0511de3](https://github.com/sgoodfriend/rl-algo-impls/tree/0511de345b17175b7cf1ea706c3e05981f11761c). The best and last models were kept from each training. This submission has loaded the best models from each training, reevaluates them, and selects the best model from these latest evaluations (mean - std).
|
31 |
|
32 |
| algo | env | seed | reward_mean | reward_std | eval_episodes | best | wandb_url |
|
33 |
|:-------|:-------------------|-------:|--------------:|-------------:|----------------:|:-------|:-----------------------------------------------------------------------------|
|
34 |
+
| a2c | PongNoFrameskip-v4 | 1 | 21 | 0 | 16 | | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/zis59lo4) |
|
35 |
+
| a2c | PongNoFrameskip-v4 | 2 | 21 | 0 | 16 | | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/b8na9vjp) |
|
36 |
+
| a2c | PongNoFrameskip-v4 | 3 | 21 | 0 | 16 | * | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/kka3ymvo) |
|
37 |
|
38 |
|
39 |
### Prerequisites: Weights & Biases (WandB)
|
|
|
53 |
Note: While the model state dictionary and hyperaparameters are saved, the latest
|
54 |
implementation could be sufficiently different to not be able to reproduce similar
|
55 |
results. You might need to checkout the commit the agent was trained on:
|
56 |
+
[0511de3](https://github.com/sgoodfriend/rl-algo-impls/tree/0511de345b17175b7cf1ea706c3e05981f11761c).
|
57 |
```
|
58 |
# Downloads the model, sets hyperparameters, and runs agent for 3 episodes
|
59 |
+
python enjoy.py --wandb-run-path=sgoodfriend/rl-algo-impls-benchmarks/kka3ymvo
|
60 |
```
|
61 |
|
62 |
Setup hasn't been completely worked out yet, so you might be best served by using Google
|
|
|
68 |
|
69 |
## Training
|
70 |
If you want the highest chance to reproduce these results, you'll want to checkout the
|
71 |
+
commit the agent was trained on: [0511de3](https://github.com/sgoodfriend/rl-algo-impls/tree/0511de345b17175b7cf1ea706c3e05981f11761c). While
|
72 |
training is deterministic, different hardware will give different results.
|
73 |
|
74 |
```
|
|
|
83 |
|
84 |
|
85 |
## Benchmarking (with Lambda Labs instance)
|
86 |
+
This and other models from https://api.wandb.ai/links/sgoodfriend/7lx79bf0 were generated by running a script on a Lambda
|
87 |
Labs instance. In a Lambda Labs instance terminal:
|
88 |
```
|
89 |
git clone git@github.com:sgoodfriend/rl-algo-impls.git
|
|
|
105 |
This isn't exactly the format of hyperparams in hyperparams/a2c.yml, but instead the Wandb Run Config. However, it's very
|
106 |
close and has some additional data:
|
107 |
```
|
108 |
+
additional_keys_to_log: []
|
109 |
algo: a2c
|
110 |
algo_hyperparams:
|
111 |
ent_coef: 0.01
|
|
|
129 |
wandb_group: null
|
130 |
wandb_project_name: rl-algo-impls-benchmarks
|
131 |
wandb_tags:
|
132 |
+
- benchmark_0511de3
|
133 |
+
- host_152-67-249-42
|
134 |
+
- branch_main
|
135 |
+
- v0.0.8
|
136 |
|
137 |
```
|
pyproject.toml
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
[project]
|
2 |
name = "rl_algo_impls"
|
3 |
-
version = "0.0.
|
4 |
description = "Implementations of reinforcement learning algorithms"
|
5 |
authors = [
|
6 |
{name = "Scott Goodfriend", email = "goodfriend.scott@gmail.com"},
|
@@ -35,6 +35,7 @@ dependencies = [
|
|
35 |
"dash",
|
36 |
"kaleido",
|
37 |
"PyYAML",
|
|
|
38 |
]
|
39 |
|
40 |
[tool.setuptools]
|
@@ -55,10 +56,30 @@ procgen = [
|
|
55 |
"glfw >= 1.12.0, < 1.13",
|
56 |
"procgen; platform_machine=='x86_64'",
|
57 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
|
59 |
[project.urls]
|
60 |
"Homepage" = "https://github.com/sgoodfriend/rl-algo-impls"
|
61 |
|
62 |
[build-system]
|
63 |
requires = ["setuptools==65.5.0", "setuptools-scm"]
|
64 |
-
build-backend = "setuptools.build_meta"
|
|
|
|
|
|
|
|
1 |
[project]
|
2 |
name = "rl_algo_impls"
|
3 |
+
version = "0.0.8"
|
4 |
description = "Implementations of reinforcement learning algorithms"
|
5 |
authors = [
|
6 |
{name = "Scott Goodfriend", email = "goodfriend.scott@gmail.com"},
|
|
|
35 |
"dash",
|
36 |
"kaleido",
|
37 |
"PyYAML",
|
38 |
+
"scikit-learn",
|
39 |
]
|
40 |
|
41 |
[tool.setuptools]
|
|
|
56 |
"glfw >= 1.12.0, < 1.13",
|
57 |
"procgen; platform_machine=='x86_64'",
|
58 |
]
|
59 |
+
microrts-old = [
|
60 |
+
"numpy < 1.24.0", # Support for gym-microrts < 0.6.0
|
61 |
+
"gym-microrts == 0.2.0", # Match ppo-implementation-details
|
62 |
+
]
|
63 |
+
microrts = [
|
64 |
+
"numpy < 1.24.0", # Support for gym-microrts < 0.6.0
|
65 |
+
"gym-microrts == 0.3.2",
|
66 |
+
]
|
67 |
+
jupyter = [
|
68 |
+
"jupyter",
|
69 |
+
"notebook"
|
70 |
+
]
|
71 |
+
all = [
|
72 |
+
"rl-algo-impls[test]",
|
73 |
+
"rl-algo-impls[procgen]",
|
74 |
+
"rl-algo-impls[microrts]",
|
75 |
+
]
|
76 |
|
77 |
[project.urls]
|
78 |
"Homepage" = "https://github.com/sgoodfriend/rl-algo-impls"
|
79 |
|
80 |
[build-system]
|
81 |
requires = ["setuptools==65.5.0", "setuptools-scm"]
|
82 |
+
build-backend = "setuptools.build_meta"
|
83 |
+
|
84 |
+
[tool.isort]
|
85 |
+
profile = "black"
|
replay.meta.json
CHANGED
@@ -1 +1 @@
|
|
1 |
-
{"content_type": "video/mp4", "encoder_version": {"backend": "ffmpeg", "version": "b'ffmpeg version 4.2.7-0ubuntu0.1 Copyright (c) 2000-2022 the FFmpeg developers\\nbuilt with gcc 9 (Ubuntu 9.4.0-1ubuntu1~20.04.1)\\nconfiguration: --prefix=/usr --extra-version=0ubuntu0.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --enable-avresample --disable-filter=resample --enable-avisynth --enable-gnutls --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libjack --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librsvg --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enable-libwavpack --enable-libwebp --enable-libx265 --enable-libxml2 --enable-libxvid --enable-libzmq --enable-libzvbi --enable-lv2 --enable-omx --enable-openal --enable-opencl --enable-opengl --enable-sdl2 --enable-libdc1394 --enable-libdrm --enable-libiec61883 --enable-nvenc --enable-chromaprint --enable-frei0r --enable-libx264 --enable-shared\\nlibavutil 56. 31.100 / 56. 31.100\\nlibavcodec 58. 54.100 / 58. 54.100\\nlibavformat 58. 29.100 / 58. 29.100\\nlibavdevice 58. 8.100 / 58. 8.100\\nlibavfilter 7. 57.100 / 7. 57.100\\nlibavresample 4. 0. 0 / 4. 0. 0\\nlibswscale 5. 5.100 / 5. 5.100\\nlibswresample 3. 5.100 / 3. 5.100\\nlibpostproc 55. 5.100 / 55. 5.100\\n'", "cmdline": ["ffmpeg", "-nostats", "-loglevel", "error", "-y", "-f", "rawvideo", "-s:v", "160x210", "-pix_fmt", "rgb24", "-framerate", "30", "-i", "-", "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2", "-vcodec", "libx264", "-pix_fmt", "yuv420p", "-r", "30", "/tmp/
|
|
|
1 |
+
{"content_type": "video/mp4", "encoder_version": {"backend": "ffmpeg", "version": "b'ffmpeg version 4.2.7-0ubuntu0.1 Copyright (c) 2000-2022 the FFmpeg developers\\nbuilt with gcc 9 (Ubuntu 9.4.0-1ubuntu1~20.04.1)\\nconfiguration: --prefix=/usr --extra-version=0ubuntu0.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --enable-avresample --disable-filter=resample --enable-avisynth --enable-gnutls --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libjack --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librsvg --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enable-libwavpack --enable-libwebp --enable-libx265 --enable-libxml2 --enable-libxvid --enable-libzmq --enable-libzvbi --enable-lv2 --enable-omx --enable-openal --enable-opencl --enable-opengl --enable-sdl2 --enable-libdc1394 --enable-libdrm --enable-libiec61883 --enable-nvenc --enable-chromaprint --enable-frei0r --enable-libx264 --enable-shared\\nlibavutil 56. 31.100 / 56. 31.100\\nlibavcodec 58. 54.100 / 58. 54.100\\nlibavformat 58. 29.100 / 58. 29.100\\nlibavdevice 58. 8.100 / 58. 8.100\\nlibavfilter 7. 57.100 / 7. 57.100\\nlibavresample 4. 0. 0 / 4. 0. 0\\nlibswscale 5. 5.100 / 5. 5.100\\nlibswresample 3. 5.100 / 3. 5.100\\nlibpostproc 55. 5.100 / 55. 5.100\\n'", "cmdline": ["ffmpeg", "-nostats", "-loglevel", "error", "-y", "-f", "rawvideo", "-s:v", "160x210", "-pix_fmt", "rgb24", "-framerate", "30", "-i", "-", "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2", "-vcodec", "libx264", "-pix_fmt", "yuv420p", "-r", "30", "/tmp/tmpi5c95nkr/a2c-PongNoFrameskip-v4/replay.mp4"]}, "episode": {"r": 21.0, "l": 7976, "t": 4.376118}}
|
rl_algo_impls/a2c/a2c.py
CHANGED
@@ -10,6 +10,7 @@ from typing import Optional, TypeVar
|
|
10 |
|
11 |
from rl_algo_impls.shared.algorithm import Algorithm
|
12 |
from rl_algo_impls.shared.callbacks.callback import Callback
|
|
|
13 |
from rl_algo_impls.shared.policy.on_policy import ActorCritic
|
14 |
from rl_algo_impls.shared.schedule import schedule, update_learning_rate
|
15 |
from rl_algo_impls.shared.stats import log_scalars
|
@@ -84,12 +85,12 @@ class A2C(Algorithm):
|
|
84 |
obs = np.zeros(epoch_dim + obs_space.shape, dtype=obs_space.dtype)
|
85 |
actions = np.zeros(epoch_dim + act_space.shape, dtype=act_space.dtype)
|
86 |
rewards = np.zeros(epoch_dim, dtype=np.float32)
|
87 |
-
episode_starts = np.zeros(epoch_dim, dtype=np.
|
88 |
values = np.zeros(epoch_dim, dtype=np.float32)
|
89 |
logprobs = np.zeros(epoch_dim, dtype=np.float32)
|
90 |
|
91 |
next_obs = self.env.reset()
|
92 |
-
next_episode_starts = np.
|
93 |
|
94 |
timesteps_elapsed = start_timesteps
|
95 |
while timesteps_elapsed < start_timesteps + train_timesteps:
|
@@ -126,23 +127,16 @@ class A2C(Algorithm):
|
|
126 |
clamped_action
|
127 |
)
|
128 |
|
129 |
-
advantages =
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
rewards[t] + self.gamma * next_value * next_nonterminal - values[t]
|
140 |
-
)
|
141 |
-
last_gae_lam = (
|
142 |
-
delta
|
143 |
-
+ self.gamma * self.gae_lambda * next_nonterminal * last_gae_lam
|
144 |
-
)
|
145 |
-
advantages[t] = last_gae_lam
|
146 |
returns = advantages + values
|
147 |
|
148 |
b_obs = torch.tensor(obs.reshape((-1,) + obs_space.shape)).to(self.device)
|
|
|
10 |
|
11 |
from rl_algo_impls.shared.algorithm import Algorithm
|
12 |
from rl_algo_impls.shared.callbacks.callback import Callback
|
13 |
+
from rl_algo_impls.shared.gae import compute_advantages
|
14 |
from rl_algo_impls.shared.policy.on_policy import ActorCritic
|
15 |
from rl_algo_impls.shared.schedule import schedule, update_learning_rate
|
16 |
from rl_algo_impls.shared.stats import log_scalars
|
|
|
85 |
obs = np.zeros(epoch_dim + obs_space.shape, dtype=obs_space.dtype)
|
86 |
actions = np.zeros(epoch_dim + act_space.shape, dtype=act_space.dtype)
|
87 |
rewards = np.zeros(epoch_dim, dtype=np.float32)
|
88 |
+
episode_starts = np.zeros(epoch_dim, dtype=np.bool8)
|
89 |
values = np.zeros(epoch_dim, dtype=np.float32)
|
90 |
logprobs = np.zeros(epoch_dim, dtype=np.float32)
|
91 |
|
92 |
next_obs = self.env.reset()
|
93 |
+
next_episode_starts = np.full(step_dim, True, dtype=np.bool8)
|
94 |
|
95 |
timesteps_elapsed = start_timesteps
|
96 |
while timesteps_elapsed < start_timesteps + train_timesteps:
|
|
|
127 |
clamped_action
|
128 |
)
|
129 |
|
130 |
+
advantages = compute_advantages(
|
131 |
+
rewards,
|
132 |
+
values,
|
133 |
+
episode_starts,
|
134 |
+
next_episode_starts,
|
135 |
+
next_obs,
|
136 |
+
self.policy,
|
137 |
+
self.gamma,
|
138 |
+
self.gae_lambda,
|
139 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
140 |
returns = advantages + values
|
141 |
|
142 |
b_obs = torch.tensor(obs.reshape((-1,) + obs_space.shape)).to(self.device)
|
rl_algo_impls/a2c/optimize.py
CHANGED
@@ -3,7 +3,7 @@ import optuna
|
|
3 |
from copy import deepcopy
|
4 |
|
5 |
from rl_algo_impls.runner.config import Config, Hyperparams, EnvHyperparams
|
6 |
-
from rl_algo_impls.
|
7 |
from rl_algo_impls.shared.policy.optimize_on_policy import sample_on_policy_hyperparams
|
8 |
from rl_algo_impls.tuning.optimize_env import sample_env_hyperparams
|
9 |
|
|
|
3 |
from copy import deepcopy
|
4 |
|
5 |
from rl_algo_impls.runner.config import Config, Hyperparams, EnvHyperparams
|
6 |
+
from rl_algo_impls.shared.vec_env import make_eval_env
|
7 |
from rl_algo_impls.shared.policy.optimize_on_policy import sample_on_policy_hyperparams
|
8 |
from rl_algo_impls.tuning.optimize_env import sample_env_hyperparams
|
9 |
|
rl_algo_impls/benchmark_publish.py
CHANGED
@@ -54,8 +54,8 @@ def benchmark_publish() -> None:
|
|
54 |
"--virtual-display", action="store_true", help="Use headless virtual display"
|
55 |
)
|
56 |
# parser.set_defaults(
|
57 |
-
# wandb_tags=["
|
58 |
-
# wandb_report_url="https://api.wandb.ai/links/sgoodfriend/
|
59 |
# envs=[],
|
60 |
# exclude_envs=[],
|
61 |
# )
|
|
|
54 |
"--virtual-display", action="store_true", help="Use headless virtual display"
|
55 |
)
|
56 |
# parser.set_defaults(
|
57 |
+
# wandb_tags=["benchmark_e47a44c", "host_129-146-2-230"],
|
58 |
+
# wandb_report_url="https://api.wandb.ai/links/sgoodfriend/v4wd7cp5",
|
59 |
# envs=[],
|
60 |
# exclude_envs=[],
|
61 |
# )
|
rl_algo_impls/compare_runs.py
CHANGED
@@ -194,5 +194,6 @@ def compare_runs() -> None:
|
|
194 |
df.loc["mean"] = df.mean(numeric_only=True)
|
195 |
print(df.to_markdown())
|
196 |
|
|
|
197 |
if __name__ == "__main__":
|
198 |
-
compare_runs()
|
|
|
194 |
df.loc["mean"] = df.mean(numeric_only=True)
|
195 |
print(df.to_markdown())
|
196 |
|
197 |
+
|
198 |
if __name__ == "__main__":
|
199 |
+
compare_runs()
|
rl_algo_impls/dqn/policy.py
CHANGED
@@ -1,16 +1,16 @@
|
|
1 |
-
import numpy as np
|
2 |
import os
|
3 |
-
import torch
|
4 |
-
|
5 |
from typing import Optional, Sequence, TypeVar
|
6 |
|
|
|
|
|
|
|
7 |
from rl_algo_impls.dqn.q_net import QNetwork
|
8 |
from rl_algo_impls.shared.policy.policy import Policy
|
9 |
from rl_algo_impls.wrappers.vectorable_wrapper import (
|
10 |
VecEnv,
|
11 |
VecEnvObs,
|
12 |
-
single_observation_space,
|
13 |
single_action_space,
|
|
|
14 |
)
|
15 |
|
16 |
DQNPolicySelf = TypeVar("DQNPolicySelf", bound="DQNPolicy")
|
@@ -21,7 +21,7 @@ class DQNPolicy(Policy):
|
|
21 |
self,
|
22 |
env: VecEnv,
|
23 |
hidden_sizes: Sequence[int] = [],
|
24 |
-
|
25 |
cnn_style: str = "nature",
|
26 |
cnn_layers_init_orthogonal: Optional[bool] = None,
|
27 |
impala_channels: Sequence[int] = (16, 32, 32),
|
@@ -32,16 +32,23 @@ class DQNPolicy(Policy):
|
|
32 |
single_observation_space(env),
|
33 |
single_action_space(env),
|
34 |
hidden_sizes,
|
35 |
-
|
36 |
cnn_style=cnn_style,
|
37 |
cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
|
38 |
impala_channels=impala_channels,
|
39 |
)
|
40 |
|
41 |
def act(
|
42 |
-
self,
|
|
|
|
|
|
|
|
|
43 |
) -> np.ndarray:
|
44 |
assert eps == 0 if deterministic else eps >= 0
|
|
|
|
|
|
|
45 |
if not deterministic and np.random.random() < eps:
|
46 |
return np.array(
|
47 |
[
|
|
|
|
|
1 |
import os
|
|
|
|
|
2 |
from typing import Optional, Sequence, TypeVar
|
3 |
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
|
7 |
from rl_algo_impls.dqn.q_net import QNetwork
|
8 |
from rl_algo_impls.shared.policy.policy import Policy
|
9 |
from rl_algo_impls.wrappers.vectorable_wrapper import (
|
10 |
VecEnv,
|
11 |
VecEnvObs,
|
|
|
12 |
single_action_space,
|
13 |
+
single_observation_space,
|
14 |
)
|
15 |
|
16 |
DQNPolicySelf = TypeVar("DQNPolicySelf", bound="DQNPolicy")
|
|
|
21 |
self,
|
22 |
env: VecEnv,
|
23 |
hidden_sizes: Sequence[int] = [],
|
24 |
+
cnn_flatten_dim: int = 512,
|
25 |
cnn_style: str = "nature",
|
26 |
cnn_layers_init_orthogonal: Optional[bool] = None,
|
27 |
impala_channels: Sequence[int] = (16, 32, 32),
|
|
|
32 |
single_observation_space(env),
|
33 |
single_action_space(env),
|
34 |
hidden_sizes,
|
35 |
+
cnn_flatten_dim=cnn_flatten_dim,
|
36 |
cnn_style=cnn_style,
|
37 |
cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
|
38 |
impala_channels=impala_channels,
|
39 |
)
|
40 |
|
41 |
def act(
|
42 |
+
self,
|
43 |
+
obs: VecEnvObs,
|
44 |
+
eps: float = 0,
|
45 |
+
deterministic: bool = True,
|
46 |
+
action_masks: Optional[np.ndarray] = None,
|
47 |
) -> np.ndarray:
|
48 |
assert eps == 0 if deterministic else eps >= 0
|
49 |
+
assert (
|
50 |
+
action_masks is None
|
51 |
+
), f"action_masks not currently supported in {self.__class__.__name__}"
|
52 |
if not deterministic and np.random.random() < eps:
|
53 |
return np.array(
|
54 |
[
|
rl_algo_impls/dqn/q_net.py
CHANGED
@@ -1,11 +1,11 @@
|
|
|
|
|
|
1 |
import gym
|
2 |
import torch as th
|
3 |
import torch.nn as nn
|
4 |
-
|
5 |
from gym.spaces import Discrete
|
6 |
-
from typing import Optional, Sequence, Type
|
7 |
|
8 |
-
from rl_algo_impls.shared.
|
9 |
from rl_algo_impls.shared.module.module import mlp
|
10 |
|
11 |
|
@@ -16,17 +16,17 @@ class QNetwork(nn.Module):
|
|
16 |
action_space: gym.Space,
|
17 |
hidden_sizes: Sequence[int] = [],
|
18 |
activation: Type[nn.Module] = nn.ReLU, # Used by stable-baselines3
|
19 |
-
|
20 |
cnn_style: str = "nature",
|
21 |
cnn_layers_init_orthogonal: Optional[bool] = None,
|
22 |
impala_channels: Sequence[int] = (16, 32, 32),
|
23 |
) -> None:
|
24 |
super().__init__()
|
25 |
assert isinstance(action_space, Discrete)
|
26 |
-
self._feature_extractor =
|
27 |
observation_space,
|
28 |
activation,
|
29 |
-
|
30 |
cnn_style=cnn_style,
|
31 |
cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
|
32 |
impala_channels=impala_channels,
|
|
|
1 |
+
from typing import Optional, Sequence, Type
|
2 |
+
|
3 |
import gym
|
4 |
import torch as th
|
5 |
import torch.nn as nn
|
|
|
6 |
from gym.spaces import Discrete
|
|
|
7 |
|
8 |
+
from rl_algo_impls.shared.encoder import Encoder
|
9 |
from rl_algo_impls.shared.module.module import mlp
|
10 |
|
11 |
|
|
|
16 |
action_space: gym.Space,
|
17 |
hidden_sizes: Sequence[int] = [],
|
18 |
activation: Type[nn.Module] = nn.ReLU, # Used by stable-baselines3
|
19 |
+
cnn_flatten_dim: int = 512,
|
20 |
cnn_style: str = "nature",
|
21 |
cnn_layers_init_orthogonal: Optional[bool] = None,
|
22 |
impala_channels: Sequence[int] = (16, 32, 32),
|
23 |
) -> None:
|
24 |
super().__init__()
|
25 |
assert isinstance(action_space, Discrete)
|
26 |
+
self._feature_extractor = Encoder(
|
27 |
observation_space,
|
28 |
activation,
|
29 |
+
cnn_flatten_dim=cnn_flatten_dim,
|
30 |
cnn_style=cnn_style,
|
31 |
cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
|
32 |
impala_channels=impala_channels,
|
rl_algo_impls/huggingface_publish.py
CHANGED
@@ -19,7 +19,7 @@ from pyvirtualdisplay.display import Display
|
|
19 |
from rl_algo_impls.publish.markdown_format import EvalTableData, model_card_text
|
20 |
from rl_algo_impls.runner.config import EnvHyperparams
|
21 |
from rl_algo_impls.runner.evaluate import EvalArgs, evaluate_model
|
22 |
-
from rl_algo_impls.
|
23 |
from rl_algo_impls.shared.callbacks.eval_callback import evaluate
|
24 |
from rl_algo_impls.wrappers.vec_episode_recorder import VecEpisodeRecorder
|
25 |
|
|
|
19 |
from rl_algo_impls.publish.markdown_format import EvalTableData, model_card_text
|
20 |
from rl_algo_impls.runner.config import EnvHyperparams
|
21 |
from rl_algo_impls.runner.evaluate import EvalArgs, evaluate_model
|
22 |
+
from rl_algo_impls.shared.vec_env import make_eval_env
|
23 |
from rl_algo_impls.shared.callbacks.eval_callback import evaluate
|
24 |
from rl_algo_impls.wrappers.vec_episode_recorder import VecEpisodeRecorder
|
25 |
|
rl_algo_impls/hyperparams/a2c.yml
CHANGED
@@ -97,31 +97,35 @@ Walker2DBulletEnv-v0:
|
|
97 |
HopperBulletEnv-v0:
|
98 |
<<: *pybullet-defaults
|
99 |
|
|
|
100 |
CarRacing-v0:
|
101 |
n_timesteps: !!float 4e6
|
102 |
env_hyperparams:
|
103 |
-
n_envs:
|
104 |
frame_stack: 4
|
105 |
normalize: true
|
106 |
normalize_kwargs:
|
107 |
norm_obs: false
|
108 |
norm_reward: true
|
109 |
policy_hyperparams:
|
110 |
-
use_sde:
|
111 |
-
log_std_init: -
|
112 |
-
init_layers_orthogonal:
|
113 |
-
activation_fn:
|
114 |
share_features_extractor: false
|
115 |
-
|
116 |
hidden_sizes: [256]
|
117 |
algo_hyperparams:
|
118 |
-
n_steps:
|
119 |
-
learning_rate:
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
vf_coef: 0.
|
|
|
|
|
|
|
125 |
|
126 |
_atari: &atari-defaults
|
127 |
n_timesteps: !!float 1e7
|
|
|
97 |
HopperBulletEnv-v0:
|
98 |
<<: *pybullet-defaults
|
99 |
|
100 |
+
# Tuned
|
101 |
CarRacing-v0:
|
102 |
n_timesteps: !!float 4e6
|
103 |
env_hyperparams:
|
104 |
+
n_envs: 16
|
105 |
frame_stack: 4
|
106 |
normalize: true
|
107 |
normalize_kwargs:
|
108 |
norm_obs: false
|
109 |
norm_reward: true
|
110 |
policy_hyperparams:
|
111 |
+
use_sde: false
|
112 |
+
log_std_init: -1.3502584927786276
|
113 |
+
init_layers_orthogonal: true
|
114 |
+
activation_fn: tanh
|
115 |
share_features_extractor: false
|
116 |
+
cnn_flatten_dim: 256
|
117 |
hidden_sizes: [256]
|
118 |
algo_hyperparams:
|
119 |
+
n_steps: 16
|
120 |
+
learning_rate: 0.000025630993245026736
|
121 |
+
learning_rate_decay: linear
|
122 |
+
gamma: 0.99957617037542
|
123 |
+
gae_lambda: 0.949455676599436
|
124 |
+
ent_coef: !!float 1.707983205298309e-7
|
125 |
+
vf_coef: 0.10428178193833336
|
126 |
+
max_grad_norm: 0.5406643389792273
|
127 |
+
normalize_advantage: true
|
128 |
+
use_rms_prop: false
|
129 |
|
130 |
_atari: &atari-defaults
|
131 |
n_timesteps: !!float 1e7
|
rl_algo_impls/hyperparams/dqn.yml
CHANGED
@@ -108,7 +108,7 @@ _impala-atari: &impala-atari-defaults
|
|
108 |
<<: *atari-defaults
|
109 |
policy_hyperparams:
|
110 |
cnn_style: impala
|
111 |
-
|
112 |
init_layers_orthogonal: true
|
113 |
cnn_layers_init_orthogonal: false
|
114 |
|
|
|
108 |
<<: *atari-defaults
|
109 |
policy_hyperparams:
|
110 |
cnn_style: impala
|
111 |
+
cnn_flatten_dim: 256
|
112 |
init_layers_orthogonal: true
|
113 |
cnn_layers_init_orthogonal: false
|
114 |
|
rl_algo_impls/hyperparams/ppo.yml
CHANGED
@@ -112,7 +112,7 @@ CarRacing-v0: &carracing-defaults
|
|
112 |
init_layers_orthogonal: false
|
113 |
activation_fn: relu
|
114 |
share_features_extractor: false
|
115 |
-
|
116 |
hidden_sizes: [256]
|
117 |
algo_hyperparams:
|
118 |
n_steps: 512
|
@@ -152,7 +152,7 @@ _atari: &atari-defaults
|
|
152 |
vec_env_class: async
|
153 |
policy_hyperparams: &atari-policy-defaults
|
154 |
activation_fn: relu
|
155 |
-
algo_hyperparams:
|
156 |
n_steps: 128
|
157 |
batch_size: 256
|
158 |
n_epochs: 4
|
@@ -192,7 +192,7 @@ _impala-atari: &impala-atari-defaults
|
|
192 |
policy_hyperparams:
|
193 |
<<: *atari-policy-defaults
|
194 |
cnn_style: impala
|
195 |
-
|
196 |
init_layers_orthogonal: true
|
197 |
cnn_layers_init_orthogonal: false
|
198 |
|
@@ -212,6 +212,126 @@ impala-QbertNoFrameskip-v4:
|
|
212 |
<<: *impala-atari-defaults
|
213 |
env_id: QbertNoFrameskip-v4
|
214 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
215 |
HalfCheetahBulletEnv-v0: &pybullet-defaults
|
216 |
n_timesteps: !!float 2e6
|
217 |
env_hyperparams: &pybullet-env-defaults
|
@@ -282,7 +402,7 @@ _procgen: &procgen-defaults
|
|
282 |
policy_hyperparams: &procgen-policy-defaults
|
283 |
activation_fn: relu
|
284 |
cnn_style: impala
|
285 |
-
|
286 |
init_layers_orthogonal: true
|
287 |
cnn_layers_init_orthogonal: false
|
288 |
algo_hyperparams: &procgen-algo-defaults
|
@@ -368,7 +488,7 @@ procgen-starpilot-hard-2xIMPALA-fat:
|
|
368 |
policy_hyperparams:
|
369 |
<<: *procgen-policy-defaults
|
370 |
impala_channels: [32, 64, 64]
|
371 |
-
|
372 |
algo_hyperparams:
|
373 |
<<: *procgen-hard-algo-defaults
|
374 |
learning_rate: !!float 2.5e-4
|
|
|
112 |
init_layers_orthogonal: false
|
113 |
activation_fn: relu
|
114 |
share_features_extractor: false
|
115 |
+
cnn_flatten_dim: 256
|
116 |
hidden_sizes: [256]
|
117 |
algo_hyperparams:
|
118 |
n_steps: 512
|
|
|
152 |
vec_env_class: async
|
153 |
policy_hyperparams: &atari-policy-defaults
|
154 |
activation_fn: relu
|
155 |
+
algo_hyperparams: &atari-algo-defaults
|
156 |
n_steps: 128
|
157 |
batch_size: 256
|
158 |
n_epochs: 4
|
|
|
192 |
policy_hyperparams:
|
193 |
<<: *atari-policy-defaults
|
194 |
cnn_style: impala
|
195 |
+
cnn_flatten_dim: 256
|
196 |
init_layers_orthogonal: true
|
197 |
cnn_layers_init_orthogonal: false
|
198 |
|
|
|
212 |
<<: *impala-atari-defaults
|
213 |
env_id: QbertNoFrameskip-v4
|
214 |
|
215 |
+
_microrts: µrts-defaults
|
216 |
+
<<: *atari-defaults
|
217 |
+
n_timesteps: !!float 2e6
|
218 |
+
env_hyperparams: µrts-env-defaults
|
219 |
+
n_envs: 8
|
220 |
+
vec_env_class: sync
|
221 |
+
mask_actions: true
|
222 |
+
policy_hyperparams: µrts-policy-defaults
|
223 |
+
<<: *atari-policy-defaults
|
224 |
+
cnn_style: microrts
|
225 |
+
cnn_flatten_dim: 128
|
226 |
+
algo_hyperparams: µrts-algo-defaults
|
227 |
+
<<: *atari-algo-defaults
|
228 |
+
clip_range_decay: none
|
229 |
+
clip_range_vf: 0.1
|
230 |
+
ppo2_vf_coef_halving: true
|
231 |
+
eval_params:
|
232 |
+
deterministic: false # Good idea because MultiCategorical mode isn't great
|
233 |
+
|
234 |
+
_no-mask-microrts: &no-mask-microrts-defaults
|
235 |
+
<<: *microrts-defaults
|
236 |
+
env_hyperparams:
|
237 |
+
<<: *microrts-env-defaults
|
238 |
+
mask_actions: false
|
239 |
+
|
240 |
+
MicrortsMining-v1-NoMask:
|
241 |
+
<<: *no-mask-microrts-defaults
|
242 |
+
env_id: MicrortsMining-v1
|
243 |
+
|
244 |
+
MicrortsAttackShapedReward-v1-NoMask:
|
245 |
+
<<: *no-mask-microrts-defaults
|
246 |
+
env_id: MicrortsAttackShapedReward-v1
|
247 |
+
|
248 |
+
MicrortsRandomEnemyShapedReward3-v1-NoMask:
|
249 |
+
<<: *no-mask-microrts-defaults
|
250 |
+
env_id: MicrortsRandomEnemyShapedReward3-v1
|
251 |
+
|
252 |
+
_microrts_ai: µrts-ai-defaults
|
253 |
+
<<: *microrts-defaults
|
254 |
+
n_timesteps: !!float 100e6
|
255 |
+
additional_keys_to_log: ["microrts_stats"]
|
256 |
+
env_hyperparams: µrts-ai-env-defaults
|
257 |
+
n_envs: 24
|
258 |
+
env_type: microrts
|
259 |
+
make_kwargs:
|
260 |
+
num_selfplay_envs: 0
|
261 |
+
max_steps: 2000
|
262 |
+
render_theme: 2
|
263 |
+
map_path: maps/16x16/basesWorkers16x16.xml
|
264 |
+
reward_weight: [10.0, 1.0, 1.0, 0.2, 1.0, 4.0]
|
265 |
+
policy_hyperparams: µrts-ai-policy-defaults
|
266 |
+
<<: *microrts-policy-defaults
|
267 |
+
cnn_flatten_dim: 256
|
268 |
+
actor_head_style: gridnet
|
269 |
+
algo_hyperparams: µrts-ai-algo-defaults
|
270 |
+
<<: *microrts-algo-defaults
|
271 |
+
learning_rate: !!float 2.5e-4
|
272 |
+
learning_rate_decay: linear
|
273 |
+
n_steps: 512
|
274 |
+
batch_size: 3072
|
275 |
+
n_epochs: 4
|
276 |
+
ent_coef: 0.01
|
277 |
+
vf_coef: 0.5
|
278 |
+
max_grad_norm: 0.5
|
279 |
+
clip_range: 0.1
|
280 |
+
clip_range_vf: 0.1
|
281 |
+
|
282 |
+
MicrortsAttackPassiveEnemySparseReward-v3:
|
283 |
+
<<: *microrts-ai-defaults
|
284 |
+
n_timesteps: !!float 2e6
|
285 |
+
env_id: MicrortsAttackPassiveEnemySparseReward-v3 # Workaround to keep model name simple
|
286 |
+
env_hyperparams:
|
287 |
+
<<: *microrts-ai-env-defaults
|
288 |
+
bots:
|
289 |
+
passiveAI: 24
|
290 |
+
|
291 |
+
MicrortsDefeatRandomEnemySparseReward-v3: µrts-random-ai-defaults
|
292 |
+
<<: *microrts-ai-defaults
|
293 |
+
n_timesteps: !!float 2e6
|
294 |
+
env_id: MicrortsDefeatRandomEnemySparseReward-v3 # Workaround to keep model name simple
|
295 |
+
env_hyperparams:
|
296 |
+
<<: *microrts-ai-env-defaults
|
297 |
+
bots:
|
298 |
+
randomBiasedAI: 24
|
299 |
+
|
300 |
+
enc-dec-MicrortsDefeatRandomEnemySparseReward-v3:
|
301 |
+
<<: *microrts-random-ai-defaults
|
302 |
+
policy_hyperparams:
|
303 |
+
<<: *microrts-ai-policy-defaults
|
304 |
+
cnn_style: gridnet_encoder
|
305 |
+
actor_head_style: gridnet_decoder
|
306 |
+
v_hidden_sizes: [128]
|
307 |
+
|
308 |
+
MicrortsDefeatCoacAIShaped-v3: µrts-coacai-defaults
|
309 |
+
<<: *microrts-ai-defaults
|
310 |
+
env_id: MicrortsDefeatCoacAIShaped-v3 # Workaround to keep model name simple
|
311 |
+
n_timesteps: !!float 300e6
|
312 |
+
env_hyperparams: µrts-coacai-env-defaults
|
313 |
+
<<: *microrts-ai-env-defaults
|
314 |
+
bots:
|
315 |
+
coacAI: 24
|
316 |
+
|
317 |
+
MicrortsDefeatCoacAIShaped-v3-diverseBots: µrts-diverse-defaults
|
318 |
+
<<: *microrts-coacai-defaults
|
319 |
+
env_hyperparams:
|
320 |
+
<<: *microrts-coacai-env-defaults
|
321 |
+
bots:
|
322 |
+
coacAI: 18
|
323 |
+
randomBiasedAI: 2
|
324 |
+
lightRushAI: 2
|
325 |
+
workerRushAI: 2
|
326 |
+
|
327 |
+
enc-dec-MicrortsDefeatCoacAIShaped-v3-diverseBots:
|
328 |
+
<<: *microrts-diverse-defaults
|
329 |
+
policy_hyperparams:
|
330 |
+
<<: *microrts-ai-policy-defaults
|
331 |
+
cnn_style: gridnet_encoder
|
332 |
+
actor_head_style: gridnet_decoder
|
333 |
+
v_hidden_sizes: [128]
|
334 |
+
|
335 |
HalfCheetahBulletEnv-v0: &pybullet-defaults
|
336 |
n_timesteps: !!float 2e6
|
337 |
env_hyperparams: &pybullet-env-defaults
|
|
|
402 |
policy_hyperparams: &procgen-policy-defaults
|
403 |
activation_fn: relu
|
404 |
cnn_style: impala
|
405 |
+
cnn_flatten_dim: 256
|
406 |
init_layers_orthogonal: true
|
407 |
cnn_layers_init_orthogonal: false
|
408 |
algo_hyperparams: &procgen-algo-defaults
|
|
|
488 |
policy_hyperparams:
|
489 |
<<: *procgen-policy-defaults
|
490 |
impala_channels: [32, 64, 64]
|
491 |
+
cnn_flatten_dim: 512
|
492 |
algo_hyperparams:
|
493 |
<<: *procgen-hard-algo-defaults
|
494 |
learning_rate: !!float 2.5e-4
|
rl_algo_impls/hyperparams/vpg.yml
CHANGED
@@ -110,7 +110,7 @@ CarRacing-v0:
|
|
110 |
log_std_init: -2
|
111 |
init_layers_orthogonal: false
|
112 |
activation_fn: relu
|
113 |
-
|
114 |
hidden_sizes: [256]
|
115 |
algo_hyperparams:
|
116 |
n_steps: 1000
|
@@ -175,9 +175,9 @@ FrozenLake-v1:
|
|
175 |
save_best: true
|
176 |
|
177 |
_atari: &atari-defaults
|
178 |
-
n_timesteps: !!float
|
179 |
env_hyperparams:
|
180 |
-
n_envs:
|
181 |
frame_stack: 4
|
182 |
no_reward_timeout_steps: 1000
|
183 |
no_reward_fire_steps: 500
|
@@ -185,7 +185,7 @@ _atari: &atari-defaults
|
|
185 |
policy_hyperparams:
|
186 |
activation_fn: relu
|
187 |
algo_hyperparams:
|
188 |
-
n_steps:
|
189 |
pi_lr: !!float 5e-5
|
190 |
gamma: 0.99
|
191 |
gae_lambda: 0.95
|
|
|
110 |
log_std_init: -2
|
111 |
init_layers_orthogonal: false
|
112 |
activation_fn: relu
|
113 |
+
cnn_flatten_dim: 256
|
114 |
hidden_sizes: [256]
|
115 |
algo_hyperparams:
|
116 |
n_steps: 1000
|
|
|
175 |
save_best: true
|
176 |
|
177 |
_atari: &atari-defaults
|
178 |
+
n_timesteps: !!float 10e6
|
179 |
env_hyperparams:
|
180 |
+
n_envs: 2
|
181 |
frame_stack: 4
|
182 |
no_reward_timeout_steps: 1000
|
183 |
no_reward_fire_steps: 500
|
|
|
185 |
policy_hyperparams:
|
186 |
activation_fn: relu
|
187 |
algo_hyperparams:
|
188 |
+
n_steps: 3072
|
189 |
pi_lr: !!float 5e-5
|
190 |
gamma: 0.99
|
191 |
gae_lambda: 0.95
|
rl_algo_impls/optimize.py
CHANGED
@@ -17,7 +17,7 @@ from typing import Callable, List, NamedTuple, Optional, Sequence, Union
|
|
17 |
|
18 |
from rl_algo_impls.a2c.optimize import sample_params as a2c_sample_params
|
19 |
from rl_algo_impls.runner.config import Config, EnvHyperparams, RunArgs
|
20 |
-
from rl_algo_impls.
|
21 |
from rl_algo_impls.runner.running_utils import (
|
22 |
base_parser,
|
23 |
load_hyperparams,
|
@@ -194,7 +194,7 @@ def simple_optimize(trial: optuna.Trial, args: RunArgs, study_args: StudyArgs) -
|
|
194 |
env = make_env(
|
195 |
config, EnvHyperparams(**config.env_hyperparams), tb_writer=tb_writer
|
196 |
)
|
197 |
-
device = get_device(config
|
198 |
policy = make_policy(args.algo, env, device, **config.policy_hyperparams)
|
199 |
algo = ALGOS[args.algo](policy, env, device, tb_writer, **config.algo_hyperparams)
|
200 |
|
@@ -274,7 +274,7 @@ def stepwise_optimize(
|
|
274 |
project=study_args.wandb_project_name,
|
275 |
entity=study_args.wandb_entity,
|
276 |
config=asdict(hyperparams),
|
277 |
-
name=f"{
|
278 |
tags=study_args.wandb_tags,
|
279 |
group=study_args.wandb_group,
|
280 |
save_code=True,
|
@@ -298,7 +298,7 @@ def stepwise_optimize(
|
|
298 |
normalize_load_path=config.model_dir_path() if i > 0 else None,
|
299 |
tb_writer=tb_writer,
|
300 |
)
|
301 |
-
device = get_device(config
|
302 |
policy = make_policy(arg.algo, env, device, **config.policy_hyperparams)
|
303 |
if i > 0:
|
304 |
policy.load(config.model_dir_path())
|
@@ -433,6 +433,7 @@ def optimize() -> None:
|
|
433 |
|
434 |
fig1 = plot_optimization_history(study)
|
435 |
fig1.write_image("opt_history.png")
|
|
|
436 |
fig2 = plot_param_importances(study)
|
437 |
fig2.write_image("param_importances.png")
|
438 |
|
|
|
17 |
|
18 |
from rl_algo_impls.a2c.optimize import sample_params as a2c_sample_params
|
19 |
from rl_algo_impls.runner.config import Config, EnvHyperparams, RunArgs
|
20 |
+
from rl_algo_impls.shared.vec_env import make_env, make_eval_env
|
21 |
from rl_algo_impls.runner.running_utils import (
|
22 |
base_parser,
|
23 |
load_hyperparams,
|
|
|
194 |
env = make_env(
|
195 |
config, EnvHyperparams(**config.env_hyperparams), tb_writer=tb_writer
|
196 |
)
|
197 |
+
device = get_device(config, env)
|
198 |
policy = make_policy(args.algo, env, device, **config.policy_hyperparams)
|
199 |
algo = ALGOS[args.algo](policy, env, device, tb_writer, **config.algo_hyperparams)
|
200 |
|
|
|
274 |
project=study_args.wandb_project_name,
|
275 |
entity=study_args.wandb_entity,
|
276 |
config=asdict(hyperparams),
|
277 |
+
name=f"{str(trial.number)}-S{base_config.seed()}",
|
278 |
tags=study_args.wandb_tags,
|
279 |
group=study_args.wandb_group,
|
280 |
save_code=True,
|
|
|
298 |
normalize_load_path=config.model_dir_path() if i > 0 else None,
|
299 |
tb_writer=tb_writer,
|
300 |
)
|
301 |
+
device = get_device(config, env)
|
302 |
policy = make_policy(arg.algo, env, device, **config.policy_hyperparams)
|
303 |
if i > 0:
|
304 |
policy.load(config.model_dir_path())
|
|
|
433 |
|
434 |
fig1 = plot_optimization_history(study)
|
435 |
fig1.write_image("opt_history.png")
|
436 |
+
|
437 |
fig2 = plot_param_importances(study)
|
438 |
fig2.write_image("param_importances.png")
|
439 |
|
rl_algo_impls/ppo/ppo.py
CHANGED
@@ -1,59 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import numpy as np
|
2 |
import torch
|
3 |
import torch.nn as nn
|
4 |
-
|
5 |
-
from dataclasses import asdict, dataclass, field
|
6 |
-
from time import perf_counter
|
7 |
from torch.optim import Adam
|
8 |
from torch.utils.tensorboard.writer import SummaryWriter
|
9 |
-
from typing import List, Optional, NamedTuple, TypeVar
|
10 |
|
11 |
from rl_algo_impls.shared.algorithm import Algorithm
|
12 |
from rl_algo_impls.shared.callbacks.callback import Callback
|
13 |
-
from rl_algo_impls.shared.gae import
|
14 |
from rl_algo_impls.shared.policy.on_policy import ActorCritic
|
15 |
-
from rl_algo_impls.shared.schedule import
|
16 |
-
|
17 |
-
|
18 |
-
|
|
|
|
|
|
|
19 |
)
|
20 |
-
from rl_algo_impls.shared.trajectory import Trajectory, TrajectoryAccumulator
|
21 |
-
from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv, VecEnvObs
|
22 |
-
|
23 |
-
|
24 |
-
@dataclass
|
25 |
-
class PPOTrajectory(Trajectory):
|
26 |
-
logp_a: List[float] = field(default_factory=list)
|
27 |
-
|
28 |
-
def add(
|
29 |
-
self,
|
30 |
-
obs: np.ndarray,
|
31 |
-
act: np.ndarray,
|
32 |
-
next_obs: np.ndarray,
|
33 |
-
rew: float,
|
34 |
-
terminated: bool,
|
35 |
-
v: float,
|
36 |
-
logp_a: float,
|
37 |
-
):
|
38 |
-
super().add(obs, act, next_obs, rew, terminated, v)
|
39 |
-
self.logp_a.append(logp_a)
|
40 |
-
|
41 |
-
|
42 |
-
class PPOTrajectoryAccumulator(TrajectoryAccumulator):
|
43 |
-
def __init__(self, num_envs: int) -> None:
|
44 |
-
super().__init__(num_envs, PPOTrajectory)
|
45 |
-
|
46 |
-
def step(
|
47 |
-
self,
|
48 |
-
obs: VecEnvObs,
|
49 |
-
action: np.ndarray,
|
50 |
-
next_obs: VecEnvObs,
|
51 |
-
reward: np.ndarray,
|
52 |
-
done: np.ndarray,
|
53 |
-
val: np.ndarray,
|
54 |
-
logp_a: np.ndarray,
|
55 |
-
) -> None:
|
56 |
-
super().step(obs, action, next_obs, reward, done, val, logp_a)
|
57 |
|
58 |
|
59 |
class TrainStepStats(NamedTuple):
|
@@ -132,39 +99,31 @@ class PPO(Algorithm):
|
|
132 |
vf_coef: float = 0.5,
|
133 |
ppo2_vf_coef_halving: bool = False,
|
134 |
max_grad_norm: float = 0.5,
|
135 |
-
update_rtg_between_epochs: bool = False,
|
136 |
sde_sample_freq: int = -1,
|
|
|
|
|
137 |
) -> None:
|
138 |
super().__init__(policy, env, device, tb_writer)
|
139 |
self.policy = policy
|
|
|
140 |
|
141 |
self.gamma = gamma
|
142 |
self.gae_lambda = gae_lambda
|
143 |
self.optimizer = Adam(self.policy.parameters(), lr=learning_rate, eps=1e-7)
|
144 |
-
self.lr_schedule = (
|
145 |
-
linear_schedule(learning_rate, 0)
|
146 |
-
if learning_rate_decay == "linear"
|
147 |
-
else constant_schedule(learning_rate)
|
148 |
-
)
|
149 |
self.max_grad_norm = max_grad_norm
|
150 |
-
self.clip_range_schedule = (
|
151 |
-
linear_schedule(clip_range, 0)
|
152 |
-
if clip_range_decay == "linear"
|
153 |
-
else constant_schedule(clip_range)
|
154 |
-
)
|
155 |
self.clip_range_vf_schedule = None
|
156 |
if clip_range_vf:
|
157 |
-
self.clip_range_vf_schedule = (
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
|
|
162 |
self.normalize_advantage = normalize_advantage
|
163 |
-
|
164 |
-
|
165 |
-
if ent_coef_decay == "linear"
|
166 |
-
else constant_schedule(ent_coef)
|
167 |
-
)
|
168 |
self.vf_coef = vf_coef
|
169 |
self.ppo2_vf_coef_halving = ppo2_vf_coef_halving
|
170 |
|
@@ -173,181 +132,243 @@ class PPO(Algorithm):
|
|
173 |
self.n_epochs = n_epochs
|
174 |
self.sde_sample_freq = sde_sample_freq
|
175 |
|
176 |
-
self.
|
|
|
177 |
|
178 |
def learn(
|
179 |
self: PPOSelf,
|
180 |
-
|
181 |
callback: Optional[Callback] = None,
|
|
|
|
|
182 |
) -> PPOSelf:
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
198 |
)
|
199 |
-
if
|
200 |
-
|
201 |
-
|
202 |
-
return self
|
203 |
-
|
204 |
-
def _collect_trajectories(self, obs: VecEnvObs) -> PPOTrajectoryAccumulator:
|
205 |
-
self.policy.eval()
|
206 |
-
accumulator = PPOTrajectoryAccumulator(self.env.num_envs)
|
207 |
-
self.policy.reset_noise()
|
208 |
-
for i in range(self.n_steps):
|
209 |
-
if self.sde_sample_freq > 0 and i > 0 and i % self.sde_sample_freq == 0:
|
210 |
-
self.policy.reset_noise()
|
211 |
-
action, value, logp_a, clamped_action = self.policy.step(obs)
|
212 |
-
next_obs, reward, done, _ = self.env.step(clamped_action)
|
213 |
-
accumulator.step(obs, action, next_obs, reward, done, value, logp_a)
|
214 |
-
obs = next_obs
|
215 |
-
return accumulator
|
216 |
-
|
217 |
-
def train(
|
218 |
-
self, trajectories: List[PPOTrajectory], progress: float, timesteps_elapsed: int
|
219 |
-
) -> TrainStats:
|
220 |
-
self.policy.train()
|
221 |
-
learning_rate = self.lr_schedule(progress)
|
222 |
-
update_learning_rate(self.optimizer, learning_rate)
|
223 |
-
self.tb_writer.add_scalar(
|
224 |
-
"charts/learning_rate",
|
225 |
-
self.optimizer.param_groups[0]["lr"],
|
226 |
-
timesteps_elapsed,
|
227 |
)
|
228 |
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
v_clip = self.clip_range_vf_schedule(progress)
|
233 |
-
self.tb_writer.add_scalar("charts/v_clip", v_clip, timesteps_elapsed)
|
234 |
-
else:
|
235 |
-
v_clip = None
|
236 |
-
ent_coef = self.ent_coef_schedule(progress)
|
237 |
-
self.tb_writer.add_scalar("charts/ent_coef", ent_coef, timesteps_elapsed)
|
238 |
-
|
239 |
-
obs = torch.as_tensor(
|
240 |
-
np.concatenate([np.array(t.obs) for t in trajectories]), device=self.device
|
241 |
-
)
|
242 |
-
act = torch.as_tensor(
|
243 |
-
np.concatenate([np.array(t.act) for t in trajectories]), device=self.device
|
244 |
-
)
|
245 |
-
rtg, adv = compute_rtg_and_advantage(
|
246 |
-
trajectories, self.policy, self.gamma, self.gae_lambda, self.device
|
247 |
-
)
|
248 |
-
orig_v = torch.as_tensor(
|
249 |
-
np.concatenate([np.array(t.v) for t in trajectories]), device=self.device
|
250 |
-
)
|
251 |
-
orig_logp_a = torch.as_tensor(
|
252 |
-
np.concatenate([np.array(t.logp_a) for t in trajectories]),
|
253 |
-
device=self.device,
|
254 |
-
)
|
255 |
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
263 |
else:
|
264 |
-
|
265 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
266 |
)
|
267 |
-
|
268 |
-
|
269 |
-
mb_idxs = idxs[i : i + self.batch_size]
|
270 |
-
mb_adv = adv[mb_idxs]
|
271 |
-
if self.normalize_advantage:
|
272 |
-
mb_adv = (mb_adv - mb_adv.mean(-1)) / (mb_adv.std(-1) + 1e-8)
|
273 |
-
self.policy.reset_noise(self.batch_size)
|
274 |
-
step_stats.append(
|
275 |
-
self._train_step(
|
276 |
-
pi_clip,
|
277 |
-
v_clip,
|
278 |
-
ent_coef,
|
279 |
-
obs[mb_idxs],
|
280 |
-
act[mb_idxs],
|
281 |
-
rtg[mb_idxs],
|
282 |
-
mb_adv,
|
283 |
-
orig_v[mb_idxs],
|
284 |
-
orig_logp_a[mb_idxs],
|
285 |
-
)
|
286 |
)
|
287 |
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
293 |
|
294 |
-
|
|
|
|
|
|
|
295 |
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
338 |
)
|
339 |
-
|
340 |
-
|
341 |
-
if v_clip
|
342 |
-
else 0
|
343 |
)
|
344 |
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from dataclasses import asdict, dataclass
|
3 |
+
from time import perf_counter
|
4 |
+
from typing import List, NamedTuple, Optional, TypeVar
|
5 |
+
|
6 |
import numpy as np
|
7 |
import torch
|
8 |
import torch.nn as nn
|
|
|
|
|
|
|
9 |
from torch.optim import Adam
|
10 |
from torch.utils.tensorboard.writer import SummaryWriter
|
|
|
11 |
|
12 |
from rl_algo_impls.shared.algorithm import Algorithm
|
13 |
from rl_algo_impls.shared.callbacks.callback import Callback
|
14 |
+
from rl_algo_impls.shared.gae import compute_advantages
|
15 |
from rl_algo_impls.shared.policy.on_policy import ActorCritic
|
16 |
+
from rl_algo_impls.shared.schedule import schedule, update_learning_rate
|
17 |
+
from rl_algo_impls.shared.stats import log_scalars
|
18 |
+
from rl_algo_impls.wrappers.action_mask_wrapper import find_action_masker
|
19 |
+
from rl_algo_impls.wrappers.vectorable_wrapper import (
|
20 |
+
VecEnv,
|
21 |
+
single_action_space,
|
22 |
+
single_observation_space,
|
23 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
|
26 |
class TrainStepStats(NamedTuple):
|
|
|
99 |
vf_coef: float = 0.5,
|
100 |
ppo2_vf_coef_halving: bool = False,
|
101 |
max_grad_norm: float = 0.5,
|
|
|
102 |
sde_sample_freq: int = -1,
|
103 |
+
update_advantage_between_epochs: bool = True,
|
104 |
+
update_returns_between_epochs: bool = False,
|
105 |
) -> None:
|
106 |
super().__init__(policy, env, device, tb_writer)
|
107 |
self.policy = policy
|
108 |
+
self.action_masker = find_action_masker(env)
|
109 |
|
110 |
self.gamma = gamma
|
111 |
self.gae_lambda = gae_lambda
|
112 |
self.optimizer = Adam(self.policy.parameters(), lr=learning_rate, eps=1e-7)
|
113 |
+
self.lr_schedule = schedule(learning_rate_decay, learning_rate)
|
|
|
|
|
|
|
|
|
114 |
self.max_grad_norm = max_grad_norm
|
115 |
+
self.clip_range_schedule = schedule(clip_range_decay, clip_range)
|
|
|
|
|
|
|
|
|
116 |
self.clip_range_vf_schedule = None
|
117 |
if clip_range_vf:
|
118 |
+
self.clip_range_vf_schedule = schedule(clip_range_vf_decay, clip_range_vf)
|
119 |
+
|
120 |
+
if normalize_advantage:
|
121 |
+
assert (
|
122 |
+
env.num_envs * n_steps > 1 and batch_size > 1
|
123 |
+
), f"Each minibatch must be larger than 1 to support normalization"
|
124 |
self.normalize_advantage = normalize_advantage
|
125 |
+
|
126 |
+
self.ent_coef_schedule = schedule(ent_coef_decay, ent_coef)
|
|
|
|
|
|
|
127 |
self.vf_coef = vf_coef
|
128 |
self.ppo2_vf_coef_halving = ppo2_vf_coef_halving
|
129 |
|
|
|
132 |
self.n_epochs = n_epochs
|
133 |
self.sde_sample_freq = sde_sample_freq
|
134 |
|
135 |
+
self.update_advantage_between_epochs = update_advantage_between_epochs
|
136 |
+
self.update_returns_between_epochs = update_returns_between_epochs
|
137 |
|
138 |
def learn(
|
139 |
self: PPOSelf,
|
140 |
+
train_timesteps: int,
|
141 |
callback: Optional[Callback] = None,
|
142 |
+
total_timesteps: Optional[int] = None,
|
143 |
+
start_timesteps: int = 0,
|
144 |
) -> PPOSelf:
|
145 |
+
if total_timesteps is None:
|
146 |
+
total_timesteps = train_timesteps
|
147 |
+
assert start_timesteps + train_timesteps <= total_timesteps
|
148 |
+
|
149 |
+
epoch_dim = (self.n_steps, self.env.num_envs)
|
150 |
+
step_dim = (self.env.num_envs,)
|
151 |
+
obs_space = single_observation_space(self.env)
|
152 |
+
act_space = single_action_space(self.env)
|
153 |
+
act_shape = self.policy.action_shape
|
154 |
+
|
155 |
+
next_obs = self.env.reset()
|
156 |
+
next_action_masks = (
|
157 |
+
self.action_masker.action_masks() if self.action_masker else None
|
158 |
+
)
|
159 |
+
next_episode_starts = np.full(step_dim, True, dtype=np.bool8)
|
160 |
+
|
161 |
+
obs = np.zeros(epoch_dim + obs_space.shape, dtype=obs_space.dtype) # type: ignore
|
162 |
+
actions = np.zeros(epoch_dim + act_shape, dtype=act_space.dtype) # type: ignore
|
163 |
+
rewards = np.zeros(epoch_dim, dtype=np.float32)
|
164 |
+
episode_starts = np.zeros(epoch_dim, dtype=np.bool8)
|
165 |
+
values = np.zeros(epoch_dim, dtype=np.float32)
|
166 |
+
logprobs = np.zeros(epoch_dim, dtype=np.float32)
|
167 |
+
action_masks = (
|
168 |
+
np.zeros(
|
169 |
+
(self.n_steps,) + next_action_masks.shape, dtype=next_action_masks.dtype
|
170 |
)
|
171 |
+
if next_action_masks is not None
|
172 |
+
else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
173 |
)
|
174 |
|
175 |
+
timesteps_elapsed = start_timesteps
|
176 |
+
while timesteps_elapsed < start_timesteps + train_timesteps:
|
177 |
+
start_time = perf_counter()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
178 |
|
179 |
+
progress = timesteps_elapsed / total_timesteps
|
180 |
+
ent_coef = self.ent_coef_schedule(progress)
|
181 |
+
learning_rate = self.lr_schedule(progress)
|
182 |
+
update_learning_rate(self.optimizer, learning_rate)
|
183 |
+
pi_clip = self.clip_range_schedule(progress)
|
184 |
+
chart_scalars = {
|
185 |
+
"learning_rate": self.optimizer.param_groups[0]["lr"],
|
186 |
+
"ent_coef": ent_coef,
|
187 |
+
"pi_clip": pi_clip,
|
188 |
+
}
|
189 |
+
if self.clip_range_vf_schedule:
|
190 |
+
v_clip = self.clip_range_vf_schedule(progress)
|
191 |
+
chart_scalars["v_clip"] = v_clip
|
192 |
else:
|
193 |
+
v_clip = None
|
194 |
+
log_scalars(self.tb_writer, "charts", chart_scalars, timesteps_elapsed)
|
195 |
+
|
196 |
+
self.policy.eval()
|
197 |
+
self.policy.reset_noise()
|
198 |
+
for s in range(self.n_steps):
|
199 |
+
timesteps_elapsed += self.env.num_envs
|
200 |
+
if self.sde_sample_freq > 0 and s > 0 and s % self.sde_sample_freq == 0:
|
201 |
+
self.policy.reset_noise()
|
202 |
+
|
203 |
+
obs[s] = next_obs
|
204 |
+
episode_starts[s] = next_episode_starts
|
205 |
+
if action_masks is not None:
|
206 |
+
action_masks[s] = next_action_masks
|
207 |
+
|
208 |
+
(
|
209 |
+
actions[s],
|
210 |
+
values[s],
|
211 |
+
logprobs[s],
|
212 |
+
clamped_action,
|
213 |
+
) = self.policy.step(next_obs, action_masks=next_action_masks)
|
214 |
+
next_obs, rewards[s], next_episode_starts, _ = self.env.step(
|
215 |
+
clamped_action
|
216 |
)
|
217 |
+
next_action_masks = (
|
218 |
+
self.action_masker.action_masks() if self.action_masker else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
219 |
)
|
220 |
|
221 |
+
self.policy.train()
|
222 |
+
|
223 |
+
b_obs = torch.tensor(obs.reshape((-1,) + obs_space.shape)).to(self.device) # type: ignore
|
224 |
+
b_actions = torch.tensor(actions.reshape((-1,) + act_shape)).to( # type: ignore
|
225 |
+
self.device
|
226 |
+
)
|
227 |
+
b_logprobs = torch.tensor(logprobs.reshape(-1)).to(self.device)
|
228 |
+
b_action_masks = (
|
229 |
+
torch.tensor(action_masks.reshape((-1,) + next_action_masks.shape[1:])).to( # type: ignore
|
230 |
+
self.device
|
231 |
+
)
|
232 |
+
if action_masks is not None
|
233 |
+
else None
|
234 |
+
)
|
235 |
+
|
236 |
+
y_pred = values.reshape(-1)
|
237 |
+
b_values = torch.tensor(y_pred).to(self.device)
|
238 |
+
|
239 |
+
step_stats = []
|
240 |
+
# Define variables that will definitely be set through the first epoch
|
241 |
+
advantages: np.ndarray = None # type: ignore
|
242 |
+
b_advantages: torch.Tensor = None # type: ignore
|
243 |
+
y_true: np.ndarray = None # type: ignore
|
244 |
+
b_returns: torch.Tensor = None # type: ignore
|
245 |
+
for e in range(self.n_epochs):
|
246 |
+
if e == 0 or self.update_advantage_between_epochs:
|
247 |
+
advantages = compute_advantages(
|
248 |
+
rewards,
|
249 |
+
values,
|
250 |
+
episode_starts,
|
251 |
+
next_episode_starts,
|
252 |
+
next_obs,
|
253 |
+
self.policy,
|
254 |
+
self.gamma,
|
255 |
+
self.gae_lambda,
|
256 |
+
)
|
257 |
+
b_advantages = torch.tensor(advantages.reshape(-1)).to(self.device)
|
258 |
+
if e == 0 or self.update_returns_between_epochs:
|
259 |
+
returns = advantages + values
|
260 |
+
y_true = returns.reshape(-1)
|
261 |
+
b_returns = torch.tensor(y_true).to(self.device)
|
262 |
+
|
263 |
+
b_idxs = torch.randperm(len(b_obs))
|
264 |
+
# Only record last epoch's stats
|
265 |
+
step_stats.clear()
|
266 |
+
for i in range(0, len(b_obs), self.batch_size):
|
267 |
+
self.policy.reset_noise(self.batch_size)
|
268 |
+
|
269 |
+
mb_idxs = b_idxs[i : i + self.batch_size]
|
270 |
+
|
271 |
+
mb_obs = b_obs[mb_idxs]
|
272 |
+
mb_actions = b_actions[mb_idxs]
|
273 |
+
mb_values = b_values[mb_idxs]
|
274 |
+
mb_logprobs = b_logprobs[mb_idxs]
|
275 |
+
mb_action_masks = (
|
276 |
+
b_action_masks[mb_idxs] if b_action_masks is not None else None
|
277 |
+
)
|
278 |
|
279 |
+
mb_adv = b_advantages[mb_idxs]
|
280 |
+
if self.normalize_advantage:
|
281 |
+
mb_adv = (mb_adv - mb_adv.mean()) / (mb_adv.std() + 1e-8)
|
282 |
+
mb_returns = b_returns[mb_idxs]
|
283 |
|
284 |
+
new_logprobs, entropy, new_values = self.policy(
|
285 |
+
mb_obs, mb_actions, action_masks=mb_action_masks
|
286 |
+
)
|
287 |
+
|
288 |
+
logratio = new_logprobs - mb_logprobs
|
289 |
+
ratio = torch.exp(logratio)
|
290 |
+
clipped_ratio = torch.clamp(ratio, min=1 - pi_clip, max=1 + pi_clip)
|
291 |
+
pi_loss = torch.max(-ratio * mb_adv, -clipped_ratio * mb_adv).mean()
|
292 |
+
|
293 |
+
v_loss_unclipped = (new_values - mb_returns) ** 2
|
294 |
+
if v_clip:
|
295 |
+
v_loss_clipped = (
|
296 |
+
mb_values
|
297 |
+
+ torch.clamp(new_values - mb_values, -v_clip, v_clip)
|
298 |
+
- mb_returns
|
299 |
+
) ** 2
|
300 |
+
v_loss = torch.max(v_loss_unclipped, v_loss_clipped).mean()
|
301 |
+
else:
|
302 |
+
v_loss = v_loss_unclipped.mean()
|
303 |
+
|
304 |
+
if self.ppo2_vf_coef_halving:
|
305 |
+
v_loss *= 0.5
|
306 |
+
|
307 |
+
entropy_loss = -entropy.mean()
|
308 |
+
|
309 |
+
loss = pi_loss + ent_coef * entropy_loss + self.vf_coef * v_loss
|
310 |
+
|
311 |
+
self.optimizer.zero_grad()
|
312 |
+
loss.backward()
|
313 |
+
nn.utils.clip_grad_norm_(
|
314 |
+
self.policy.parameters(), self.max_grad_norm
|
315 |
+
)
|
316 |
+
self.optimizer.step()
|
317 |
+
|
318 |
+
with torch.no_grad():
|
319 |
+
approx_kl = ((ratio - 1) - logratio).mean().cpu().numpy().item()
|
320 |
+
clipped_frac = (
|
321 |
+
((ratio - 1).abs() > pi_clip)
|
322 |
+
.float()
|
323 |
+
.mean()
|
324 |
+
.cpu()
|
325 |
+
.numpy()
|
326 |
+
.item()
|
327 |
+
)
|
328 |
+
val_clipped_frac = (
|
329 |
+
((new_values - mb_values).abs() > v_clip)
|
330 |
+
.float()
|
331 |
+
.mean()
|
332 |
+
.cpu()
|
333 |
+
.numpy()
|
334 |
+
.item()
|
335 |
+
if v_clip
|
336 |
+
else 0
|
337 |
+
)
|
338 |
+
|
339 |
+
step_stats.append(
|
340 |
+
TrainStepStats(
|
341 |
+
loss.item(),
|
342 |
+
pi_loss.item(),
|
343 |
+
v_loss.item(),
|
344 |
+
entropy_loss.item(),
|
345 |
+
approx_kl,
|
346 |
+
clipped_frac,
|
347 |
+
val_clipped_frac,
|
348 |
+
)
|
349 |
+
)
|
350 |
+
|
351 |
+
var_y = np.var(y_true).item()
|
352 |
+
explained_var = (
|
353 |
+
np.nan if var_y == 0 else 1 - np.var(y_true - y_pred).item() / var_y
|
354 |
)
|
355 |
+
TrainStats(step_stats, explained_var).write_to_tensorboard(
|
356 |
+
self.tb_writer, timesteps_elapsed
|
|
|
|
|
357 |
)
|
358 |
|
359 |
+
end_time = perf_counter()
|
360 |
+
rollout_steps = self.n_steps * self.env.num_envs
|
361 |
+
self.tb_writer.add_scalar(
|
362 |
+
"train/steps_per_second",
|
363 |
+
rollout_steps / (end_time - start_time),
|
364 |
+
timesteps_elapsed,
|
365 |
+
)
|
366 |
+
|
367 |
+
if callback:
|
368 |
+
if not callback.on_step(timesteps_elapsed=rollout_steps):
|
369 |
+
logging.info(
|
370 |
+
f"Callback terminated training at {timesteps_elapsed} timesteps"
|
371 |
+
)
|
372 |
+
break
|
373 |
+
|
374 |
+
return self
|
rl_algo_impls/runner/config.py
CHANGED
@@ -2,12 +2,10 @@ import dataclasses
|
|
2 |
import inspect
|
3 |
import itertools
|
4 |
import os
|
5 |
-
|
6 |
-
from datetime import datetime
|
7 |
from dataclasses import dataclass
|
|
|
8 |
from typing import Any, Dict, List, Optional, Type, TypeVar, Union
|
9 |
|
10 |
-
|
11 |
RunArgsSelf = TypeVar("RunArgsSelf", bound="RunArgs")
|
12 |
|
13 |
|
@@ -50,6 +48,9 @@ class EnvHyperparams:
|
|
50 |
video_step_interval: Union[int, float] = 1_000_000
|
51 |
initial_steps_to_truncate: Optional[int] = None
|
52 |
clip_atari_rewards: bool = True
|
|
|
|
|
|
|
53 |
|
54 |
|
55 |
HyperparamsSelf = TypeVar("HyperparamsSelf", bound="Hyperparams")
|
@@ -64,6 +65,7 @@ class Hyperparams:
|
|
64 |
algo_hyperparams: Dict[str, Any] = dataclasses.field(default_factory=dict)
|
65 |
eval_params: Dict[str, Any] = dataclasses.field(default_factory=dict)
|
66 |
env_id: Optional[str] = None
|
|
|
67 |
|
68 |
@classmethod
|
69 |
def from_dict_with_extra_fields(
|
@@ -119,6 +121,10 @@ class Config:
|
|
119 |
def env_id(self) -> str:
|
120 |
return self.hyperparams.env_id or self.args.env
|
121 |
|
|
|
|
|
|
|
|
|
122 |
def model_name(self, include_seed: bool = True) -> str:
|
123 |
# Use arg env name instead of environment name
|
124 |
parts = [self.algo, self.args.env]
|
|
|
2 |
import inspect
|
3 |
import itertools
|
4 |
import os
|
|
|
|
|
5 |
from dataclasses import dataclass
|
6 |
+
from datetime import datetime
|
7 |
from typing import Any, Dict, List, Optional, Type, TypeVar, Union
|
8 |
|
|
|
9 |
RunArgsSelf = TypeVar("RunArgsSelf", bound="RunArgs")
|
10 |
|
11 |
|
|
|
48 |
video_step_interval: Union[int, float] = 1_000_000
|
49 |
initial_steps_to_truncate: Optional[int] = None
|
50 |
clip_atari_rewards: bool = True
|
51 |
+
normalize_type: Optional[str] = None
|
52 |
+
mask_actions: bool = False
|
53 |
+
bots: Optional[Dict[str, int]] = None
|
54 |
|
55 |
|
56 |
HyperparamsSelf = TypeVar("HyperparamsSelf", bound="Hyperparams")
|
|
|
65 |
algo_hyperparams: Dict[str, Any] = dataclasses.field(default_factory=dict)
|
66 |
eval_params: Dict[str, Any] = dataclasses.field(default_factory=dict)
|
67 |
env_id: Optional[str] = None
|
68 |
+
additional_keys_to_log: List[str] = dataclasses.field(default_factory=list)
|
69 |
|
70 |
@classmethod
|
71 |
def from_dict_with_extra_fields(
|
|
|
121 |
def env_id(self) -> str:
|
122 |
return self.hyperparams.env_id or self.args.env
|
123 |
|
124 |
+
@property
|
125 |
+
def additional_keys_to_log(self) -> List[str]:
|
126 |
+
return self.hyperparams.additional_keys_to_log
|
127 |
+
|
128 |
def model_name(self, include_seed: bool = True) -> str:
|
129 |
# Use arg env name instead of environment name
|
130 |
parts = [self.algo, self.args.env]
|
rl_algo_impls/runner/evaluate.py
CHANGED
@@ -4,7 +4,7 @@ import shutil
|
|
4 |
from dataclasses import dataclass
|
5 |
from typing import NamedTuple, Optional
|
6 |
|
7 |
-
from rl_algo_impls.
|
8 |
from rl_algo_impls.runner.config import Config, EnvHyperparams, Hyperparams, RunArgs
|
9 |
from rl_algo_impls.runner.running_utils import (
|
10 |
load_hyperparams,
|
@@ -75,7 +75,7 @@ def evaluate_model(args: EvalArgs, root_dir: str) -> Evaluation:
|
|
75 |
render=args.render,
|
76 |
normalize_load_path=model_path,
|
77 |
)
|
78 |
-
device = get_device(config
|
79 |
policy = make_policy(
|
80 |
args.algo,
|
81 |
env,
|
|
|
4 |
from dataclasses import dataclass
|
5 |
from typing import NamedTuple, Optional
|
6 |
|
7 |
+
from rl_algo_impls.shared.vec_env import make_eval_env
|
8 |
from rl_algo_impls.runner.config import Config, EnvHyperparams, Hyperparams, RunArgs
|
9 |
from rl_algo_impls.runner.running_utils import (
|
10 |
load_hyperparams,
|
|
|
75 |
render=args.render,
|
76 |
normalize_load_path=model_path,
|
77 |
)
|
78 |
+
device = get_device(config, env)
|
79 |
policy = make_policy(
|
80 |
args.algo,
|
81 |
env,
|
rl_algo_impls/runner/running_utils.py
CHANGED
@@ -1,32 +1,32 @@
|
|
1 |
import argparse
|
2 |
-
import gym
|
3 |
import json
|
4 |
-
import matplotlib.pyplot as plt
|
5 |
-
import numpy as np
|
6 |
import os
|
7 |
import random
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
import torch
|
9 |
import torch.backends.cudnn
|
10 |
import yaml
|
11 |
-
|
12 |
-
from dataclasses import asdict
|
13 |
from gym.spaces import Box, Discrete
|
14 |
-
from pathlib import Path
|
15 |
from torch.utils.tensorboard.writer import SummaryWriter
|
16 |
-
from typing import Dict, Optional, Type, Union
|
17 |
-
|
18 |
-
from rl_algo_impls.runner.config import Hyperparams
|
19 |
-
from rl_algo_impls.shared.algorithm import Algorithm
|
20 |
-
from rl_algo_impls.shared.callbacks.eval_callback import EvalCallback
|
21 |
-
from rl_algo_impls.shared.policy.on_policy import ActorCritic
|
22 |
-
from rl_algo_impls.shared.policy.policy import Policy
|
23 |
|
24 |
from rl_algo_impls.a2c.a2c import A2C
|
25 |
from rl_algo_impls.dqn.dqn import DQN
|
26 |
from rl_algo_impls.dqn.policy import DQNPolicy
|
27 |
from rl_algo_impls.ppo.ppo import PPO
|
28 |
-
from rl_algo_impls.
|
|
|
|
|
|
|
|
|
|
|
29 |
from rl_algo_impls.vpg.policy import VPGActorCritic
|
|
|
30 |
from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv, single_observation_space
|
31 |
|
32 |
ALGOS: Dict[str, Type[Algorithm]] = {
|
@@ -81,16 +81,19 @@ def load_hyperparams(algo: str, env_id: str) -> Hyperparams:
|
|
81 |
if env_id in hyperparams_dict:
|
82 |
return Hyperparams(**hyperparams_dict[env_id])
|
83 |
|
84 |
-
|
85 |
-
import pybullet_envs
|
86 |
spec = gym.spec(env_id)
|
87 |
-
|
|
|
88 |
return Hyperparams(**hyperparams_dict["_atari"])
|
|
|
|
|
89 |
else:
|
90 |
raise ValueError(f"{env_id} not specified in {algo} hyperparameters file")
|
91 |
|
92 |
|
93 |
-
def get_device(
|
|
|
94 |
# cuda by default
|
95 |
if device == "auto":
|
96 |
device = "cuda"
|
@@ -108,6 +111,16 @@ def get_device(device: str, env: VecEnv) -> torch.device:
|
|
108 |
device = "cpu"
|
109 |
elif isinstance(obs_space, Box) and len(obs_space.shape) == 1:
|
110 |
device = "cpu"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
print(f"Device: {device}")
|
112 |
return torch.device(device)
|
113 |
|
@@ -187,6 +200,8 @@ def hparam_dict(
|
|
187 |
flattened[key] = str(sv)
|
188 |
else:
|
189 |
flattened[key] = sv
|
|
|
|
|
190 |
else:
|
191 |
flattened[k] = v # type: ignore
|
192 |
return flattened # type: ignore
|
|
|
1 |
import argparse
|
|
|
2 |
import json
|
|
|
|
|
3 |
import os
|
4 |
import random
|
5 |
+
from dataclasses import asdict
|
6 |
+
from pathlib import Path
|
7 |
+
from typing import Dict, Optional, Type, Union
|
8 |
+
|
9 |
+
import gym
|
10 |
+
import matplotlib.pyplot as plt
|
11 |
+
import numpy as np
|
12 |
import torch
|
13 |
import torch.backends.cudnn
|
14 |
import yaml
|
|
|
|
|
15 |
from gym.spaces import Box, Discrete
|
|
|
16 |
from torch.utils.tensorboard.writer import SummaryWriter
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
from rl_algo_impls.a2c.a2c import A2C
|
19 |
from rl_algo_impls.dqn.dqn import DQN
|
20 |
from rl_algo_impls.dqn.policy import DQNPolicy
|
21 |
from rl_algo_impls.ppo.ppo import PPO
|
22 |
+
from rl_algo_impls.runner.config import Config, Hyperparams
|
23 |
+
from rl_algo_impls.shared.algorithm import Algorithm
|
24 |
+
from rl_algo_impls.shared.callbacks.eval_callback import EvalCallback
|
25 |
+
from rl_algo_impls.shared.policy.on_policy import ActorCritic
|
26 |
+
from rl_algo_impls.shared.policy.policy import Policy
|
27 |
+
from rl_algo_impls.shared.vec_env.utils import import_for_env_id, is_microrts
|
28 |
from rl_algo_impls.vpg.policy import VPGActorCritic
|
29 |
+
from rl_algo_impls.vpg.vpg import VanillaPolicyGradient
|
30 |
from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv, single_observation_space
|
31 |
|
32 |
ALGOS: Dict[str, Type[Algorithm]] = {
|
|
|
81 |
if env_id in hyperparams_dict:
|
82 |
return Hyperparams(**hyperparams_dict[env_id])
|
83 |
|
84 |
+
import_for_env_id(env_id)
|
|
|
85 |
spec = gym.spec(env_id)
|
86 |
+
entry_point_name = str(spec.entry_point) # type: ignore
|
87 |
+
if "AtariEnv" in entry_point_name and "_atari" in hyperparams_dict:
|
88 |
return Hyperparams(**hyperparams_dict["_atari"])
|
89 |
+
elif "gym_microrts" in entry_point_name and "_microrts" in hyperparams_dict:
|
90 |
+
return Hyperparams(**hyperparams_dict["_microrts"])
|
91 |
else:
|
92 |
raise ValueError(f"{env_id} not specified in {algo} hyperparameters file")
|
93 |
|
94 |
|
95 |
+
def get_device(config: Config, env: VecEnv) -> torch.device:
|
96 |
+
device = config.device
|
97 |
# cuda by default
|
98 |
if device == "auto":
|
99 |
device = "cuda"
|
|
|
111 |
device = "cpu"
|
112 |
elif isinstance(obs_space, Box) and len(obs_space.shape) == 1:
|
113 |
device = "cpu"
|
114 |
+
if is_microrts(config):
|
115 |
+
try:
|
116 |
+
from gym_microrts.envs.vec_env import MicroRTSGridModeVecEnv
|
117 |
+
|
118 |
+
# Models that move more than one unit at a time should use mps
|
119 |
+
if not isinstance(env.unwrapped, MicroRTSGridModeVecEnv):
|
120 |
+
device = "cpu"
|
121 |
+
except ModuleNotFoundError:
|
122 |
+
# Likely on gym_microrts v0.0.2 to match ppo-implementation-details
|
123 |
+
device = "cpu"
|
124 |
print(f"Device: {device}")
|
125 |
return torch.device(device)
|
126 |
|
|
|
200 |
flattened[key] = str(sv)
|
201 |
else:
|
202 |
flattened[key] = sv
|
203 |
+
elif isinstance(v, list):
|
204 |
+
flattened[k] = json.dumps(v)
|
205 |
else:
|
206 |
flattened[k] = v # type: ignore
|
207 |
return flattened # type: ignore
|
rl_algo_impls/runner/train.py
CHANGED
@@ -5,26 +5,26 @@ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
|
5 |
|
6 |
import dataclasses
|
7 |
import shutil
|
8 |
-
import wandb
|
9 |
-
import yaml
|
10 |
-
|
11 |
from dataclasses import asdict, dataclass
|
12 |
-
from torch.utils.tensorboard.writer import SummaryWriter
|
13 |
from typing import Any, Dict, Optional, Sequence
|
14 |
|
15 |
-
|
|
|
|
|
|
|
16 |
from rl_algo_impls.runner.config import Config, EnvHyperparams, RunArgs
|
17 |
-
from rl_algo_impls.runner.env import make_env, make_eval_env
|
18 |
from rl_algo_impls.runner.running_utils import (
|
19 |
ALGOS,
|
20 |
-
load_hyperparams,
|
21 |
-
set_seeds,
|
22 |
get_device,
|
|
|
|
|
23 |
make_policy,
|
24 |
plot_eval_callback,
|
25 |
-
|
26 |
)
|
|
|
27 |
from rl_algo_impls.shared.stats import EpisodesStats
|
|
|
28 |
|
29 |
|
30 |
@dataclass
|
@@ -65,7 +65,7 @@ def train(args: TrainArgs):
|
|
65 |
env = make_env(
|
66 |
config, EnvHyperparams(**config.env_hyperparams), tb_writer=tb_writer
|
67 |
)
|
68 |
-
device = get_device(config
|
69 |
policy = make_policy(args.algo, env, device, **config.policy_hyperparams)
|
70 |
algo = ALGOS[args.algo](policy, env, device, tb_writer, **config.algo_hyperparams)
|
71 |
|
@@ -94,6 +94,7 @@ def train(args: TrainArgs):
|
|
94 |
if record_best_videos
|
95 |
else None,
|
96 |
best_video_dir=config.best_videos_dir,
|
|
|
97 |
)
|
98 |
algo.learn(config.n_timesteps, callback=callback)
|
99 |
|
|
|
5 |
|
6 |
import dataclasses
|
7 |
import shutil
|
|
|
|
|
|
|
8 |
from dataclasses import asdict, dataclass
|
|
|
9 |
from typing import Any, Dict, Optional, Sequence
|
10 |
|
11 |
+
import yaml
|
12 |
+
from torch.utils.tensorboard.writer import SummaryWriter
|
13 |
+
|
14 |
+
import wandb
|
15 |
from rl_algo_impls.runner.config import Config, EnvHyperparams, RunArgs
|
|
|
16 |
from rl_algo_impls.runner.running_utils import (
|
17 |
ALGOS,
|
|
|
|
|
18 |
get_device,
|
19 |
+
hparam_dict,
|
20 |
+
load_hyperparams,
|
21 |
make_policy,
|
22 |
plot_eval_callback,
|
23 |
+
set_seeds,
|
24 |
)
|
25 |
+
from rl_algo_impls.shared.callbacks.eval_callback import EvalCallback
|
26 |
from rl_algo_impls.shared.stats import EpisodesStats
|
27 |
+
from rl_algo_impls.shared.vec_env import make_env, make_eval_env
|
28 |
|
29 |
|
30 |
@dataclass
|
|
|
65 |
env = make_env(
|
66 |
config, EnvHyperparams(**config.env_hyperparams), tb_writer=tb_writer
|
67 |
)
|
68 |
+
device = get_device(config, env)
|
69 |
policy = make_policy(args.algo, env, device, **config.policy_hyperparams)
|
70 |
algo = ALGOS[args.algo](policy, env, device, tb_writer, **config.algo_hyperparams)
|
71 |
|
|
|
94 |
if record_best_videos
|
95 |
else None,
|
96 |
best_video_dir=config.best_videos_dir,
|
97 |
+
additional_keys_to_log=config.additional_keys_to_log,
|
98 |
)
|
99 |
algo.learn(config.n_timesteps, callback=callback)
|
100 |
|
rl_algo_impls/shared/actor/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from rl_algo_impls.shared.actor.actor import Actor, PiForward
|
2 |
+
from rl_algo_impls.shared.actor.make_actor import actor_head
|
rl_algo_impls/shared/actor/actor.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
from typing import NamedTuple, Optional, Tuple
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from torch.distributions import Distribution
|
8 |
+
|
9 |
+
|
10 |
+
class PiForward(NamedTuple):
|
11 |
+
pi: Distribution
|
12 |
+
logp_a: Optional[torch.Tensor]
|
13 |
+
entropy: Optional[torch.Tensor]
|
14 |
+
|
15 |
+
|
16 |
+
class Actor(nn.Module, ABC):
|
17 |
+
@abstractmethod
|
18 |
+
def forward(
|
19 |
+
self,
|
20 |
+
obs: torch.Tensor,
|
21 |
+
actions: Optional[torch.Tensor] = None,
|
22 |
+
action_masks: Optional[torch.Tensor] = None,
|
23 |
+
) -> PiForward:
|
24 |
+
...
|
25 |
+
|
26 |
+
def sample_weights(self, batch_size: int = 1) -> None:
|
27 |
+
pass
|
28 |
+
|
29 |
+
@property
|
30 |
+
@abstractmethod
|
31 |
+
def action_shape(self) -> Tuple[int, ...]:
|
32 |
+
...
|
33 |
+
|
34 |
+
def pi_forward(
|
35 |
+
self, distribution: Distribution, actions: Optional[torch.Tensor] = None
|
36 |
+
) -> PiForward:
|
37 |
+
logp_a = None
|
38 |
+
entropy = None
|
39 |
+
if actions is not None:
|
40 |
+
logp_a = distribution.log_prob(actions)
|
41 |
+
entropy = distribution.entropy()
|
42 |
+
return PiForward(distribution, logp_a, entropy)
|
rl_algo_impls/shared/actor/categorical.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Tuple, Type
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from torch.distributions import Categorical
|
6 |
+
|
7 |
+
from rl_algo_impls.shared.actor import Actor, PiForward
|
8 |
+
from rl_algo_impls.shared.module.module import mlp
|
9 |
+
|
10 |
+
|
11 |
+
class MaskedCategorical(Categorical):
|
12 |
+
def __init__(
|
13 |
+
self,
|
14 |
+
probs=None,
|
15 |
+
logits=None,
|
16 |
+
validate_args=None,
|
17 |
+
mask: Optional[torch.Tensor] = None,
|
18 |
+
):
|
19 |
+
if mask is not None:
|
20 |
+
assert logits is not None, "mask requires logits and not probs"
|
21 |
+
logits = torch.where(mask, logits, -1e8)
|
22 |
+
self.mask = mask
|
23 |
+
super().__init__(probs, logits, validate_args)
|
24 |
+
|
25 |
+
def entropy(self) -> torch.Tensor:
|
26 |
+
if self.mask is None:
|
27 |
+
return super().entropy()
|
28 |
+
# If mask set, then use approximation for entropy
|
29 |
+
p_log_p = self.logits * self.probs # type: ignore
|
30 |
+
masked = torch.where(self.mask, p_log_p, 0)
|
31 |
+
return -masked.sum(-1)
|
32 |
+
|
33 |
+
|
34 |
+
class CategoricalActorHead(Actor):
|
35 |
+
def __init__(
|
36 |
+
self,
|
37 |
+
act_dim: int,
|
38 |
+
in_dim: int,
|
39 |
+
hidden_sizes: Tuple[int, ...] = (32,),
|
40 |
+
activation: Type[nn.Module] = nn.Tanh,
|
41 |
+
init_layers_orthogonal: bool = True,
|
42 |
+
) -> None:
|
43 |
+
super().__init__()
|
44 |
+
layer_sizes = (in_dim,) + hidden_sizes + (act_dim,)
|
45 |
+
self._fc = mlp(
|
46 |
+
layer_sizes,
|
47 |
+
activation,
|
48 |
+
init_layers_orthogonal=init_layers_orthogonal,
|
49 |
+
final_layer_gain=0.01,
|
50 |
+
)
|
51 |
+
|
52 |
+
def forward(
|
53 |
+
self,
|
54 |
+
obs: torch.Tensor,
|
55 |
+
actions: Optional[torch.Tensor] = None,
|
56 |
+
action_masks: Optional[torch.Tensor] = None,
|
57 |
+
) -> PiForward:
|
58 |
+
logits = self._fc(obs)
|
59 |
+
pi = MaskedCategorical(logits=logits, mask=action_masks)
|
60 |
+
return self.pi_forward(pi, actions)
|
61 |
+
|
62 |
+
@property
|
63 |
+
def action_shape(self) -> Tuple[int, ...]:
|
64 |
+
return ()
|
rl_algo_impls/shared/actor/gaussian.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Tuple, Type
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from torch.distributions import Distribution, Normal
|
6 |
+
|
7 |
+
from rl_algo_impls.shared.actor.actor import Actor, PiForward
|
8 |
+
from rl_algo_impls.shared.module.module import mlp
|
9 |
+
|
10 |
+
|
11 |
+
class GaussianDistribution(Normal):
|
12 |
+
def log_prob(self, a: torch.Tensor) -> torch.Tensor:
|
13 |
+
return super().log_prob(a).sum(axis=-1)
|
14 |
+
|
15 |
+
def sample(self) -> torch.Tensor:
|
16 |
+
return self.rsample()
|
17 |
+
|
18 |
+
|
19 |
+
class GaussianActorHead(Actor):
|
20 |
+
def __init__(
|
21 |
+
self,
|
22 |
+
act_dim: int,
|
23 |
+
in_dim: int,
|
24 |
+
hidden_sizes: Tuple[int, ...] = (32,),
|
25 |
+
activation: Type[nn.Module] = nn.Tanh,
|
26 |
+
init_layers_orthogonal: bool = True,
|
27 |
+
log_std_init: float = -0.5,
|
28 |
+
) -> None:
|
29 |
+
super().__init__()
|
30 |
+
self.act_dim = act_dim
|
31 |
+
layer_sizes = (in_dim,) + hidden_sizes + (act_dim,)
|
32 |
+
self.mu_net = mlp(
|
33 |
+
layer_sizes,
|
34 |
+
activation,
|
35 |
+
init_layers_orthogonal=init_layers_orthogonal,
|
36 |
+
final_layer_gain=0.01,
|
37 |
+
)
|
38 |
+
self.log_std = nn.Parameter(
|
39 |
+
torch.ones(act_dim, dtype=torch.float32) * log_std_init
|
40 |
+
)
|
41 |
+
|
42 |
+
def _distribution(self, obs: torch.Tensor) -> Distribution:
|
43 |
+
mu = self.mu_net(obs)
|
44 |
+
std = torch.exp(self.log_std)
|
45 |
+
return GaussianDistribution(mu, std)
|
46 |
+
|
47 |
+
def forward(
|
48 |
+
self,
|
49 |
+
obs: torch.Tensor,
|
50 |
+
actions: Optional[torch.Tensor] = None,
|
51 |
+
action_masks: Optional[torch.Tensor] = None,
|
52 |
+
) -> PiForward:
|
53 |
+
assert (
|
54 |
+
not action_masks
|
55 |
+
), f"{self.__class__.__name__} does not support action_masks"
|
56 |
+
pi = self._distribution(obs)
|
57 |
+
return self.pi_forward(pi, actions)
|
58 |
+
|
59 |
+
@property
|
60 |
+
def action_shape(self) -> Tuple[int, ...]:
|
61 |
+
return (self.act_dim,)
|
rl_algo_impls/shared/actor/gridnet.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, Optional, Tuple, Type
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from numpy.typing import NDArray
|
7 |
+
from torch.distributions import Distribution, constraints
|
8 |
+
|
9 |
+
from rl_algo_impls.shared.actor import Actor, PiForward
|
10 |
+
from rl_algo_impls.shared.actor.categorical import MaskedCategorical
|
11 |
+
from rl_algo_impls.shared.encoder import EncoderOutDim
|
12 |
+
from rl_algo_impls.shared.module.module import mlp
|
13 |
+
|
14 |
+
|
15 |
+
class GridnetDistribution(Distribution):
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
map_size: int,
|
19 |
+
action_vec: NDArray[np.int64],
|
20 |
+
logits: torch.Tensor,
|
21 |
+
masks: torch.Tensor,
|
22 |
+
validate_args: Optional[bool] = None,
|
23 |
+
) -> None:
|
24 |
+
self.map_size = map_size
|
25 |
+
self.action_vec = action_vec
|
26 |
+
|
27 |
+
masks = masks.view(-1, masks.shape[-1])
|
28 |
+
split_masks = torch.split(masks[:, 1:], action_vec.tolist(), dim=1)
|
29 |
+
|
30 |
+
grid_logits = logits.reshape(-1, action_vec.sum())
|
31 |
+
split_logits = torch.split(grid_logits, action_vec.tolist(), dim=1)
|
32 |
+
self.categoricals = [
|
33 |
+
MaskedCategorical(logits=lg, validate_args=validate_args, mask=m)
|
34 |
+
for lg, m in zip(split_logits, split_masks)
|
35 |
+
]
|
36 |
+
|
37 |
+
batch_shape = logits.size()[:-1] if logits.ndimension() > 1 else torch.Size()
|
38 |
+
super().__init__(batch_shape=batch_shape, validate_args=validate_args)
|
39 |
+
|
40 |
+
def log_prob(self, action: torch.Tensor) -> torch.Tensor:
|
41 |
+
prob_stack = torch.stack(
|
42 |
+
[
|
43 |
+
c.log_prob(a)
|
44 |
+
for a, c in zip(action.view(-1, action.shape[-1]).T, self.categoricals)
|
45 |
+
],
|
46 |
+
dim=-1,
|
47 |
+
)
|
48 |
+
logprob = prob_stack.view(-1, self.map_size, len(self.action_vec))
|
49 |
+
return logprob.sum(dim=(1, 2))
|
50 |
+
|
51 |
+
def entropy(self) -> torch.Tensor:
|
52 |
+
ent = torch.stack([c.entropy() for c in self.categoricals], dim=-1)
|
53 |
+
ent = ent.view(-1, self.map_size, len(self.action_vec))
|
54 |
+
return ent.sum(dim=(1, 2))
|
55 |
+
|
56 |
+
def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
|
57 |
+
s = torch.stack([c.sample(sample_shape) for c in self.categoricals], dim=-1)
|
58 |
+
return s.view(-1, self.map_size, len(self.action_vec))
|
59 |
+
|
60 |
+
@property
|
61 |
+
def mode(self) -> torch.Tensor:
|
62 |
+
m = torch.stack([c.mode for c in self.categoricals], dim=-1)
|
63 |
+
return m.view(-1, self.map_size, len(self.action_vec))
|
64 |
+
|
65 |
+
@property
|
66 |
+
def arg_constraints(self) -> Dict[str, constraints.Constraint]:
|
67 |
+
# Constraints handled by child distributions in dist
|
68 |
+
return {}
|
69 |
+
|
70 |
+
|
71 |
+
class GridnetActorHead(Actor):
|
72 |
+
def __init__(
|
73 |
+
self,
|
74 |
+
map_size: int,
|
75 |
+
action_vec: NDArray[np.int64],
|
76 |
+
in_dim: EncoderOutDim,
|
77 |
+
hidden_sizes: Tuple[int, ...] = (32,),
|
78 |
+
activation: Type[nn.Module] = nn.ReLU,
|
79 |
+
init_layers_orthogonal: bool = True,
|
80 |
+
) -> None:
|
81 |
+
super().__init__()
|
82 |
+
self.map_size = map_size
|
83 |
+
self.action_vec = action_vec
|
84 |
+
assert isinstance(in_dim, int)
|
85 |
+
layer_sizes = (in_dim,) + hidden_sizes + (map_size * action_vec.sum(),)
|
86 |
+
self._fc = mlp(
|
87 |
+
layer_sizes,
|
88 |
+
activation,
|
89 |
+
init_layers_orthogonal=init_layers_orthogonal,
|
90 |
+
final_layer_gain=0.01,
|
91 |
+
)
|
92 |
+
|
93 |
+
def forward(
|
94 |
+
self,
|
95 |
+
obs: torch.Tensor,
|
96 |
+
actions: Optional[torch.Tensor] = None,
|
97 |
+
action_masks: Optional[torch.Tensor] = None,
|
98 |
+
) -> PiForward:
|
99 |
+
assert (
|
100 |
+
action_masks is not None
|
101 |
+
), f"No mask case unhandled in {self.__class__.__name__}"
|
102 |
+
logits = self._fc(obs)
|
103 |
+
pi = GridnetDistribution(self.map_size, self.action_vec, logits, action_masks)
|
104 |
+
return self.pi_forward(pi, actions)
|
105 |
+
|
106 |
+
@property
|
107 |
+
def action_shape(self) -> Tuple[int, ...]:
|
108 |
+
return (self.map_size, len(self.action_vec))
|
rl_algo_impls/shared/actor/gridnet_decoder.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Tuple, Type
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from numpy.typing import NDArray
|
7 |
+
|
8 |
+
from rl_algo_impls.shared.actor import Actor, PiForward
|
9 |
+
from rl_algo_impls.shared.actor.categorical import MaskedCategorical
|
10 |
+
from rl_algo_impls.shared.actor.gridnet import GridnetDistribution
|
11 |
+
from rl_algo_impls.shared.encoder import EncoderOutDim
|
12 |
+
from rl_algo_impls.shared.module.module import layer_init
|
13 |
+
|
14 |
+
|
15 |
+
class Transpose(nn.Module):
|
16 |
+
def __init__(self, permutation: Tuple[int, ...]) -> None:
|
17 |
+
super().__init__()
|
18 |
+
self.permutation = permutation
|
19 |
+
|
20 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
21 |
+
return x.permute(self.permutation)
|
22 |
+
|
23 |
+
|
24 |
+
class GridnetDecoder(Actor):
|
25 |
+
def __init__(
|
26 |
+
self,
|
27 |
+
map_size: int,
|
28 |
+
action_vec: NDArray[np.int64],
|
29 |
+
in_dim: EncoderOutDim,
|
30 |
+
activation: Type[nn.Module] = nn.ReLU,
|
31 |
+
init_layers_orthogonal: bool = True,
|
32 |
+
) -> None:
|
33 |
+
super().__init__()
|
34 |
+
self.map_size = map_size
|
35 |
+
self.action_vec = action_vec
|
36 |
+
assert isinstance(in_dim, tuple)
|
37 |
+
self.deconv = nn.Sequential(
|
38 |
+
layer_init(
|
39 |
+
nn.ConvTranspose2d(
|
40 |
+
in_dim[0], 128, 3, stride=2, padding=1, output_padding=1
|
41 |
+
),
|
42 |
+
init_layers_orthogonal=init_layers_orthogonal,
|
43 |
+
),
|
44 |
+
activation(),
|
45 |
+
layer_init(
|
46 |
+
nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
|
47 |
+
init_layers_orthogonal=init_layers_orthogonal,
|
48 |
+
),
|
49 |
+
activation(),
|
50 |
+
layer_init(
|
51 |
+
nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
|
52 |
+
init_layers_orthogonal=init_layers_orthogonal,
|
53 |
+
),
|
54 |
+
activation(),
|
55 |
+
layer_init(
|
56 |
+
nn.ConvTranspose2d(
|
57 |
+
32, action_vec.sum(), 3, stride=2, padding=1, output_padding=1
|
58 |
+
),
|
59 |
+
init_layers_orthogonal=init_layers_orthogonal,
|
60 |
+
std=0.01,
|
61 |
+
),
|
62 |
+
Transpose((0, 2, 3, 1)),
|
63 |
+
)
|
64 |
+
|
65 |
+
def forward(
|
66 |
+
self,
|
67 |
+
obs: torch.Tensor,
|
68 |
+
actions: Optional[torch.Tensor] = None,
|
69 |
+
action_masks: Optional[torch.Tensor] = None,
|
70 |
+
) -> PiForward:
|
71 |
+
assert (
|
72 |
+
action_masks is not None
|
73 |
+
), f"No mask case unhandled in {self.__class__.__name__}"
|
74 |
+
logits = self.deconv(obs)
|
75 |
+
pi = GridnetDistribution(self.map_size, self.action_vec, logits, action_masks)
|
76 |
+
return self.pi_forward(pi, actions)
|
77 |
+
|
78 |
+
@property
|
79 |
+
def action_shape(self) -> Tuple[int, ...]:
|
80 |
+
return (self.map_size, len(self.action_vec))
|
rl_algo_impls/shared/actor/make_actor.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Tuple, Type
|
2 |
+
|
3 |
+
import gym
|
4 |
+
import torch.nn as nn
|
5 |
+
from gym.spaces import Box, Discrete, MultiDiscrete
|
6 |
+
|
7 |
+
from rl_algo_impls.shared.actor.actor import Actor
|
8 |
+
from rl_algo_impls.shared.actor.categorical import CategoricalActorHead
|
9 |
+
from rl_algo_impls.shared.actor.gaussian import GaussianActorHead
|
10 |
+
from rl_algo_impls.shared.actor.gridnet import GridnetActorHead
|
11 |
+
from rl_algo_impls.shared.actor.gridnet_decoder import GridnetDecoder
|
12 |
+
from rl_algo_impls.shared.actor.multi_discrete import MultiDiscreteActorHead
|
13 |
+
from rl_algo_impls.shared.actor.state_dependent_noise import (
|
14 |
+
StateDependentNoiseActorHead,
|
15 |
+
)
|
16 |
+
from rl_algo_impls.shared.encoder import EncoderOutDim
|
17 |
+
|
18 |
+
|
19 |
+
def actor_head(
|
20 |
+
action_space: gym.Space,
|
21 |
+
in_dim: EncoderOutDim,
|
22 |
+
hidden_sizes: Tuple[int, ...],
|
23 |
+
init_layers_orthogonal: bool,
|
24 |
+
activation: Type[nn.Module],
|
25 |
+
log_std_init: float = -0.5,
|
26 |
+
use_sde: bool = False,
|
27 |
+
full_std: bool = True,
|
28 |
+
squash_output: bool = False,
|
29 |
+
actor_head_style: str = "single",
|
30 |
+
) -> Actor:
|
31 |
+
assert not use_sde or isinstance(
|
32 |
+
action_space, Box
|
33 |
+
), "use_sde only valid if Box action_space"
|
34 |
+
assert not squash_output or use_sde, "squash_output only valid if use_sde"
|
35 |
+
if isinstance(action_space, Discrete):
|
36 |
+
assert isinstance(in_dim, int)
|
37 |
+
return CategoricalActorHead(
|
38 |
+
action_space.n, # type: ignore
|
39 |
+
in_dim=in_dim,
|
40 |
+
hidden_sizes=hidden_sizes,
|
41 |
+
activation=activation,
|
42 |
+
init_layers_orthogonal=init_layers_orthogonal,
|
43 |
+
)
|
44 |
+
elif isinstance(action_space, Box):
|
45 |
+
assert isinstance(in_dim, int)
|
46 |
+
if use_sde:
|
47 |
+
return StateDependentNoiseActorHead(
|
48 |
+
action_space.shape[0], # type: ignore
|
49 |
+
in_dim=in_dim,
|
50 |
+
hidden_sizes=hidden_sizes,
|
51 |
+
activation=activation,
|
52 |
+
init_layers_orthogonal=init_layers_orthogonal,
|
53 |
+
log_std_init=log_std_init,
|
54 |
+
full_std=full_std,
|
55 |
+
squash_output=squash_output,
|
56 |
+
)
|
57 |
+
else:
|
58 |
+
return GaussianActorHead(
|
59 |
+
action_space.shape[0], # type: ignore
|
60 |
+
in_dim=in_dim,
|
61 |
+
hidden_sizes=hidden_sizes,
|
62 |
+
activation=activation,
|
63 |
+
init_layers_orthogonal=init_layers_orthogonal,
|
64 |
+
log_std_init=log_std_init,
|
65 |
+
)
|
66 |
+
elif isinstance(action_space, MultiDiscrete):
|
67 |
+
if actor_head_style == "single":
|
68 |
+
return MultiDiscreteActorHead(
|
69 |
+
action_space.nvec, # type: ignore
|
70 |
+
in_dim=in_dim,
|
71 |
+
hidden_sizes=hidden_sizes,
|
72 |
+
activation=activation,
|
73 |
+
init_layers_orthogonal=init_layers_orthogonal,
|
74 |
+
)
|
75 |
+
elif actor_head_style == "gridnet":
|
76 |
+
return GridnetActorHead(
|
77 |
+
action_space.nvec[0], # type: ignore
|
78 |
+
action_space.nvec[1:], # type: ignore
|
79 |
+
in_dim=in_dim,
|
80 |
+
hidden_sizes=hidden_sizes,
|
81 |
+
activation=activation,
|
82 |
+
init_layers_orthogonal=init_layers_orthogonal,
|
83 |
+
)
|
84 |
+
elif actor_head_style == "gridnet_decoder":
|
85 |
+
return GridnetDecoder(
|
86 |
+
action_space.nvec[0], # type: ignore
|
87 |
+
action_space.nvec[1:], # type: ignore
|
88 |
+
in_dim=in_dim,
|
89 |
+
activation=activation,
|
90 |
+
init_layers_orthogonal=init_layers_orthogonal,
|
91 |
+
)
|
92 |
+
else:
|
93 |
+
raise ValueError(f"Doesn't support actor_head_style {actor_head_style}")
|
94 |
+
else:
|
95 |
+
raise ValueError(f"Unsupported action space: {action_space}")
|
rl_algo_impls/shared/actor/multi_discrete.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, Optional, Tuple, Type
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from numpy.typing import NDArray
|
7 |
+
from torch.distributions import Distribution, constraints
|
8 |
+
|
9 |
+
from rl_algo_impls.shared.actor.actor import Actor, PiForward
|
10 |
+
from rl_algo_impls.shared.actor.categorical import MaskedCategorical
|
11 |
+
from rl_algo_impls.shared.encoder import EncoderOutDim
|
12 |
+
from rl_algo_impls.shared.module.module import mlp
|
13 |
+
|
14 |
+
|
15 |
+
class MultiCategorical(Distribution):
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
nvec: NDArray[np.int64],
|
19 |
+
probs=None,
|
20 |
+
logits=None,
|
21 |
+
validate_args=None,
|
22 |
+
masks: Optional[torch.Tensor] = None,
|
23 |
+
):
|
24 |
+
# Either probs or logits should be set
|
25 |
+
assert (probs is None) != (logits is None)
|
26 |
+
masks_split = (
|
27 |
+
torch.split(masks, nvec.tolist(), dim=1)
|
28 |
+
if masks is not None
|
29 |
+
else [None] * len(nvec)
|
30 |
+
)
|
31 |
+
if probs:
|
32 |
+
self.dists = [
|
33 |
+
MaskedCategorical(probs=p, validate_args=validate_args, mask=m)
|
34 |
+
for p, m in zip(torch.split(probs, nvec.tolist(), dim=1), masks_split)
|
35 |
+
]
|
36 |
+
param = probs
|
37 |
+
else:
|
38 |
+
assert logits is not None
|
39 |
+
self.dists = [
|
40 |
+
MaskedCategorical(logits=lg, validate_args=validate_args, mask=m)
|
41 |
+
for lg, m in zip(torch.split(logits, nvec.tolist(), dim=1), masks_split)
|
42 |
+
]
|
43 |
+
param = logits
|
44 |
+
batch_shape = param.size()[:-1] if param.ndimension() > 1 else torch.Size()
|
45 |
+
super().__init__(batch_shape=batch_shape, validate_args=validate_args)
|
46 |
+
|
47 |
+
def log_prob(self, action: torch.Tensor) -> torch.Tensor:
|
48 |
+
prob_stack = torch.stack(
|
49 |
+
[c.log_prob(a) for a, c in zip(action.T, self.dists)], dim=-1
|
50 |
+
)
|
51 |
+
return prob_stack.sum(dim=-1)
|
52 |
+
|
53 |
+
def entropy(self) -> torch.Tensor:
|
54 |
+
return torch.stack([c.entropy() for c in self.dists], dim=-1).sum(dim=-1)
|
55 |
+
|
56 |
+
def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
|
57 |
+
return torch.stack([c.sample(sample_shape) for c in self.dists], dim=-1)
|
58 |
+
|
59 |
+
@property
|
60 |
+
def mode(self) -> torch.Tensor:
|
61 |
+
return torch.stack([c.mode for c in self.dists], dim=-1)
|
62 |
+
|
63 |
+
@property
|
64 |
+
def arg_constraints(self) -> Dict[str, constraints.Constraint]:
|
65 |
+
# Constraints handled by child distributions in dist
|
66 |
+
return {}
|
67 |
+
|
68 |
+
|
69 |
+
class MultiDiscreteActorHead(Actor):
|
70 |
+
def __init__(
|
71 |
+
self,
|
72 |
+
nvec: NDArray[np.int64],
|
73 |
+
in_dim: EncoderOutDim,
|
74 |
+
hidden_sizes: Tuple[int, ...] = (32,),
|
75 |
+
activation: Type[nn.Module] = nn.ReLU,
|
76 |
+
init_layers_orthogonal: bool = True,
|
77 |
+
) -> None:
|
78 |
+
super().__init__()
|
79 |
+
self.nvec = nvec
|
80 |
+
assert isinstance(in_dim, int)
|
81 |
+
layer_sizes = (in_dim,) + hidden_sizes + (nvec.sum(),)
|
82 |
+
self._fc = mlp(
|
83 |
+
layer_sizes,
|
84 |
+
activation,
|
85 |
+
init_layers_orthogonal=init_layers_orthogonal,
|
86 |
+
final_layer_gain=0.01,
|
87 |
+
)
|
88 |
+
|
89 |
+
def forward(
|
90 |
+
self,
|
91 |
+
obs: torch.Tensor,
|
92 |
+
actions: Optional[torch.Tensor] = None,
|
93 |
+
action_masks: Optional[torch.Tensor] = None,
|
94 |
+
) -> PiForward:
|
95 |
+
logits = self._fc(obs)
|
96 |
+
pi = MultiCategorical(self.nvec, logits=logits, masks=action_masks)
|
97 |
+
return self.pi_forward(pi, actions)
|
98 |
+
|
99 |
+
@property
|
100 |
+
def action_shape(self) -> Tuple[int, ...]:
|
101 |
+
return (len(self.nvec),)
|
rl_algo_impls/shared/{policy/actor.py → actor/state_dependent_noise.py}
RENAMED
@@ -1,99 +1,13 @@
|
|
1 |
-
import
|
|
|
2 |
import torch
|
3 |
import torch.nn as nn
|
|
|
4 |
|
5 |
-
from
|
6 |
-
from gym.spaces import Box, Discrete
|
7 |
-
from torch.distributions import Categorical, Distribution, Normal
|
8 |
-
from typing import NamedTuple, Optional, Sequence, Type, TypeVar, Union
|
9 |
-
|
10 |
from rl_algo_impls.shared.module.module import mlp
|
11 |
|
12 |
|
13 |
-
class PiForward(NamedTuple):
|
14 |
-
pi: Distribution
|
15 |
-
logp_a: Optional[torch.Tensor]
|
16 |
-
entropy: Optional[torch.Tensor]
|
17 |
-
|
18 |
-
|
19 |
-
class Actor(nn.Module, ABC):
|
20 |
-
@abstractmethod
|
21 |
-
def forward(self, obs: torch.Tensor, a: Optional[torch.Tensor] = None) -> PiForward:
|
22 |
-
...
|
23 |
-
|
24 |
-
|
25 |
-
class CategoricalActorHead(Actor):
|
26 |
-
def __init__(
|
27 |
-
self,
|
28 |
-
act_dim: int,
|
29 |
-
hidden_sizes: Sequence[int] = (32,),
|
30 |
-
activation: Type[nn.Module] = nn.Tanh,
|
31 |
-
init_layers_orthogonal: bool = True,
|
32 |
-
) -> None:
|
33 |
-
super().__init__()
|
34 |
-
layer_sizes = tuple(hidden_sizes) + (act_dim,)
|
35 |
-
self._fc = mlp(
|
36 |
-
layer_sizes,
|
37 |
-
activation,
|
38 |
-
init_layers_orthogonal=init_layers_orthogonal,
|
39 |
-
final_layer_gain=0.01,
|
40 |
-
)
|
41 |
-
|
42 |
-
def forward(self, obs: torch.Tensor, a: Optional[torch.Tensor] = None) -> PiForward:
|
43 |
-
logits = self._fc(obs)
|
44 |
-
pi = Categorical(logits=logits)
|
45 |
-
logp_a = None
|
46 |
-
entropy = None
|
47 |
-
if a is not None:
|
48 |
-
logp_a = pi.log_prob(a)
|
49 |
-
entropy = pi.entropy()
|
50 |
-
return PiForward(pi, logp_a, entropy)
|
51 |
-
|
52 |
-
|
53 |
-
class GaussianDistribution(Normal):
|
54 |
-
def log_prob(self, a: torch.Tensor) -> torch.Tensor:
|
55 |
-
return super().log_prob(a).sum(axis=-1)
|
56 |
-
|
57 |
-
def sample(self) -> torch.Tensor:
|
58 |
-
return self.rsample()
|
59 |
-
|
60 |
-
|
61 |
-
class GaussianActorHead(Actor):
|
62 |
-
def __init__(
|
63 |
-
self,
|
64 |
-
act_dim: int,
|
65 |
-
hidden_sizes: Sequence[int] = (32,),
|
66 |
-
activation: Type[nn.Module] = nn.Tanh,
|
67 |
-
init_layers_orthogonal: bool = True,
|
68 |
-
log_std_init: float = -0.5,
|
69 |
-
) -> None:
|
70 |
-
super().__init__()
|
71 |
-
layer_sizes = tuple(hidden_sizes) + (act_dim,)
|
72 |
-
self.mu_net = mlp(
|
73 |
-
layer_sizes,
|
74 |
-
activation,
|
75 |
-
init_layers_orthogonal=init_layers_orthogonal,
|
76 |
-
final_layer_gain=0.01,
|
77 |
-
)
|
78 |
-
self.log_std = nn.Parameter(
|
79 |
-
torch.ones(act_dim, dtype=torch.float32) * log_std_init
|
80 |
-
)
|
81 |
-
|
82 |
-
def _distribution(self, obs: torch.Tensor) -> Distribution:
|
83 |
-
mu = self.mu_net(obs)
|
84 |
-
std = torch.exp(self.log_std)
|
85 |
-
return GaussianDistribution(mu, std)
|
86 |
-
|
87 |
-
def forward(self, obs: torch.Tensor, a: Optional[torch.Tensor] = None) -> PiForward:
|
88 |
-
pi = self._distribution(obs)
|
89 |
-
logp_a = None
|
90 |
-
entropy = None
|
91 |
-
if a is not None:
|
92 |
-
logp_a = pi.log_prob(a)
|
93 |
-
entropy = pi.entropy()
|
94 |
-
return PiForward(pi, logp_a, entropy)
|
95 |
-
|
96 |
-
|
97 |
class TanhBijector:
|
98 |
def __init__(self, epsilon: float = 1e-6) -> None:
|
99 |
self.epsilon = epsilon
|
@@ -173,7 +87,8 @@ class StateDependentNoiseActorHead(Actor):
|
|
173 |
def __init__(
|
174 |
self,
|
175 |
act_dim: int,
|
176 |
-
|
|
|
177 |
activation: Type[nn.Module] = nn.Tanh,
|
178 |
init_layers_orthogonal: bool = True,
|
179 |
log_std_init: float = -0.5,
|
@@ -183,7 +98,7 @@ class StateDependentNoiseActorHead(Actor):
|
|
183 |
) -> None:
|
184 |
super().__init__()
|
185 |
self.act_dim = act_dim
|
186 |
-
layer_sizes =
|
187 |
if len(layer_sizes) == 2:
|
188 |
self.latent_net = nn.Identity()
|
189 |
elif len(layer_sizes) > 2:
|
@@ -193,8 +108,6 @@ class StateDependentNoiseActorHead(Actor):
|
|
193 |
output_activation=activation,
|
194 |
init_layers_orthogonal=init_layers_orthogonal,
|
195 |
)
|
196 |
-
else:
|
197 |
-
raise ValueError("hidden_sizes must be of at least length 1")
|
198 |
self.mu_net = mlp(
|
199 |
layer_sizes[-2:],
|
200 |
activation,
|
@@ -202,7 +115,7 @@ class StateDependentNoiseActorHead(Actor):
|
|
202 |
final_layer_gain=0.01,
|
203 |
)
|
204 |
self.full_std = full_std
|
205 |
-
std_dim = (
|
206 |
self.log_std = nn.Parameter(
|
207 |
torch.ones(std_dim, dtype=torch.float32) * log_std_init
|
208 |
)
|
@@ -249,14 +162,17 @@ class StateDependentNoiseActorHead(Actor):
|
|
249 |
ones = ones.to(self.device)
|
250 |
return ones * std
|
251 |
|
252 |
-
def forward(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
253 |
pi = self._distribution(obs)
|
254 |
-
|
255 |
-
entropy = None
|
256 |
-
if a is not None:
|
257 |
-
logp_a = pi.log_prob(a)
|
258 |
-
entropy = -logp_a if self.bijector else sum_independent_dims(pi.entropy())
|
259 |
-
return PiForward(pi, logp_a, entropy)
|
260 |
|
261 |
def sample_weights(self, batch_size: int = 1) -> None:
|
262 |
std = self._get_std()
|
@@ -265,46 +181,20 @@ class StateDependentNoiseActorHead(Actor):
|
|
265 |
self.exploration_mat = weights_dist.rsample()
|
266 |
self.exploration_matrices = weights_dist.rsample(torch.Size((batch_size,)))
|
267 |
|
|
|
|
|
|
|
268 |
|
269 |
-
def
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
action_space, Box
|
281 |
-
), "use_sde only valid if Box action_space"
|
282 |
-
assert not squash_output or use_sde, "squash_output only valid if use_sde"
|
283 |
-
if isinstance(action_space, Discrete):
|
284 |
-
return CategoricalActorHead(
|
285 |
-
action_space.n,
|
286 |
-
hidden_sizes=hidden_sizes,
|
287 |
-
activation=activation,
|
288 |
-
init_layers_orthogonal=init_layers_orthogonal,
|
289 |
-
)
|
290 |
-
elif isinstance(action_space, Box):
|
291 |
-
if use_sde:
|
292 |
-
return StateDependentNoiseActorHead(
|
293 |
-
action_space.shape[0],
|
294 |
-
hidden_sizes=hidden_sizes,
|
295 |
-
activation=activation,
|
296 |
-
init_layers_orthogonal=init_layers_orthogonal,
|
297 |
-
log_std_init=log_std_init,
|
298 |
-
full_std=full_std,
|
299 |
-
squash_output=squash_output,
|
300 |
-
)
|
301 |
-
else:
|
302 |
-
return GaussianActorHead(
|
303 |
-
action_space.shape[0],
|
304 |
-
hidden_sizes=hidden_sizes,
|
305 |
-
activation=activation,
|
306 |
-
init_layers_orthogonal=init_layers_orthogonal,
|
307 |
-
log_std_init=log_std_init,
|
308 |
)
|
309 |
-
|
310 |
-
raise ValueError(f"Unsupported action space: {action_space}")
|
|
|
1 |
+
from typing import Optional, Tuple, Type, TypeVar, Union
|
2 |
+
|
3 |
import torch
|
4 |
import torch.nn as nn
|
5 |
+
from torch.distributions import Distribution, Normal
|
6 |
|
7 |
+
from rl_algo_impls.shared.actor.actor import Actor, PiForward
|
|
|
|
|
|
|
|
|
8 |
from rl_algo_impls.shared.module.module import mlp
|
9 |
|
10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
class TanhBijector:
|
12 |
def __init__(self, epsilon: float = 1e-6) -> None:
|
13 |
self.epsilon = epsilon
|
|
|
87 |
def __init__(
|
88 |
self,
|
89 |
act_dim: int,
|
90 |
+
in_dim: int,
|
91 |
+
hidden_sizes: Tuple[int, ...] = (32,),
|
92 |
activation: Type[nn.Module] = nn.Tanh,
|
93 |
init_layers_orthogonal: bool = True,
|
94 |
log_std_init: float = -0.5,
|
|
|
98 |
) -> None:
|
99 |
super().__init__()
|
100 |
self.act_dim = act_dim
|
101 |
+
layer_sizes = (in_dim,) + hidden_sizes + (act_dim,)
|
102 |
if len(layer_sizes) == 2:
|
103 |
self.latent_net = nn.Identity()
|
104 |
elif len(layer_sizes) > 2:
|
|
|
108 |
output_activation=activation,
|
109 |
init_layers_orthogonal=init_layers_orthogonal,
|
110 |
)
|
|
|
|
|
111 |
self.mu_net = mlp(
|
112 |
layer_sizes[-2:],
|
113 |
activation,
|
|
|
115 |
final_layer_gain=0.01,
|
116 |
)
|
117 |
self.full_std = full_std
|
118 |
+
std_dim = (layer_sizes[-2], act_dim if self.full_std else 1)
|
119 |
self.log_std = nn.Parameter(
|
120 |
torch.ones(std_dim, dtype=torch.float32) * log_std_init
|
121 |
)
|
|
|
162 |
ones = ones.to(self.device)
|
163 |
return ones * std
|
164 |
|
165 |
+
def forward(
|
166 |
+
self,
|
167 |
+
obs: torch.Tensor,
|
168 |
+
actions: Optional[torch.Tensor] = None,
|
169 |
+
action_masks: Optional[torch.Tensor] = None,
|
170 |
+
) -> PiForward:
|
171 |
+
assert (
|
172 |
+
not action_masks
|
173 |
+
), f"{self.__class__.__name__} does not support action_masks"
|
174 |
pi = self._distribution(obs)
|
175 |
+
return self.pi_forward(pi, actions)
|
|
|
|
|
|
|
|
|
|
|
176 |
|
177 |
def sample_weights(self, batch_size: int = 1) -> None:
|
178 |
std = self._get_std()
|
|
|
181 |
self.exploration_mat = weights_dist.rsample()
|
182 |
self.exploration_matrices = weights_dist.rsample(torch.Size((batch_size,)))
|
183 |
|
184 |
+
@property
|
185 |
+
def action_shape(self) -> Tuple[int, ...]:
|
186 |
+
return (self.act_dim,)
|
187 |
|
188 |
+
def pi_forward(
|
189 |
+
self, distribution: Distribution, actions: Optional[torch.Tensor] = None
|
190 |
+
) -> PiForward:
|
191 |
+
logp_a = None
|
192 |
+
entropy = None
|
193 |
+
if actions is not None:
|
194 |
+
logp_a = distribution.log_prob(actions)
|
195 |
+
entropy = (
|
196 |
+
-logp_a
|
197 |
+
if self.bijector
|
198 |
+
else sum_independent_dims(distribution.entropy())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
199 |
)
|
200 |
+
return PiForward(distribution, logp_a, entropy)
|
|
rl_algo_impls/shared/callbacks/eval_callback.py
CHANGED
@@ -1,14 +1,15 @@
|
|
1 |
import itertools
|
2 |
-
import numpy as np
|
3 |
import os
|
4 |
-
|
5 |
from time import perf_counter
|
|
|
|
|
|
|
6 |
from torch.utils.tensorboard.writer import SummaryWriter
|
7 |
-
from typing import List, Optional, Union
|
8 |
|
9 |
from rl_algo_impls.shared.callbacks.callback import Callback
|
10 |
from rl_algo_impls.shared.policy.policy import Policy
|
11 |
from rl_algo_impls.shared.stats import Episode, EpisodeAccumulator, EpisodesStats
|
|
|
12 |
from rl_algo_impls.wrappers.vec_episode_recorder import VecEpisodeRecorder
|
13 |
from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv
|
14 |
|
@@ -20,6 +21,7 @@ class EvaluateAccumulator(EpisodeAccumulator):
|
|
20 |
goal_episodes: int,
|
21 |
print_returns: bool = True,
|
22 |
ignore_first_episode: bool = False,
|
|
|
23 |
):
|
24 |
super().__init__(num_envs)
|
25 |
self.completed_episodes_by_env_idx = [[] for _ in range(num_envs)]
|
@@ -36,8 +38,11 @@ class EvaluateAccumulator(EpisodeAccumulator):
|
|
36 |
self.should_record_done = should_record_done
|
37 |
else:
|
38 |
self.should_record_done = lambda idx: True
|
|
|
39 |
|
40 |
-
def on_done(self, ep_idx: int, episode: Episode) -> None:
|
|
|
|
|
41 |
if (
|
42 |
self.should_record_done(ep_idx)
|
43 |
and len(self.completed_episodes_by_env_idx[ep_idx])
|
@@ -74,19 +79,29 @@ def evaluate(
|
|
74 |
deterministic: bool = True,
|
75 |
print_returns: bool = True,
|
76 |
ignore_first_episode: bool = False,
|
|
|
77 |
) -> EpisodesStats:
|
78 |
policy.sync_normalization(env)
|
79 |
policy.eval()
|
80 |
|
81 |
episodes = EvaluateAccumulator(
|
82 |
-
env.num_envs,
|
|
|
|
|
|
|
|
|
83 |
)
|
84 |
|
85 |
obs = env.reset()
|
|
|
86 |
while not episodes.is_done():
|
87 |
-
act = policy.act(
|
88 |
-
|
89 |
-
|
|
|
|
|
|
|
|
|
90 |
if render:
|
91 |
env.render()
|
92 |
stats = EpisodesStats(episodes.episodes)
|
@@ -111,6 +126,7 @@ class EvalCallback(Callback):
|
|
111 |
best_video_dir: Optional[str] = None,
|
112 |
max_video_length: int = 3600,
|
113 |
ignore_first_episode: bool = False,
|
|
|
114 |
) -> None:
|
115 |
super().__init__()
|
116 |
self.policy = policy
|
@@ -133,8 +149,8 @@ class EvalCallback(Callback):
|
|
133 |
os.makedirs(best_video_dir, exist_ok=True)
|
134 |
self.max_video_length = max_video_length
|
135 |
self.best_video_base_path = None
|
136 |
-
|
137 |
self.ignore_first_episode = ignore_first_episode
|
|
|
138 |
|
139 |
def on_step(self, timesteps_elapsed: int = 1) -> bool:
|
140 |
super().on_step(timesteps_elapsed)
|
@@ -153,6 +169,7 @@ class EvalCallback(Callback):
|
|
153 |
deterministic=self.deterministic,
|
154 |
print_returns=print_returns or False,
|
155 |
ignore_first_episode=self.ignore_first_episode,
|
|
|
156 |
)
|
157 |
end_time = perf_counter()
|
158 |
self.tb_writer.add_scalar(
|
|
|
1 |
import itertools
|
|
|
2 |
import os
|
|
|
3 |
from time import perf_counter
|
4 |
+
from typing import Dict, List, Optional, Union
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
from torch.utils.tensorboard.writer import SummaryWriter
|
|
|
8 |
|
9 |
from rl_algo_impls.shared.callbacks.callback import Callback
|
10 |
from rl_algo_impls.shared.policy.policy import Policy
|
11 |
from rl_algo_impls.shared.stats import Episode, EpisodeAccumulator, EpisodesStats
|
12 |
+
from rl_algo_impls.wrappers.action_mask_wrapper import find_action_masker
|
13 |
from rl_algo_impls.wrappers.vec_episode_recorder import VecEpisodeRecorder
|
14 |
from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv
|
15 |
|
|
|
21 |
goal_episodes: int,
|
22 |
print_returns: bool = True,
|
23 |
ignore_first_episode: bool = False,
|
24 |
+
additional_keys_to_log: Optional[List[str]] = None,
|
25 |
):
|
26 |
super().__init__(num_envs)
|
27 |
self.completed_episodes_by_env_idx = [[] for _ in range(num_envs)]
|
|
|
38 |
self.should_record_done = should_record_done
|
39 |
else:
|
40 |
self.should_record_done = lambda idx: True
|
41 |
+
self.additional_keys_to_log = additional_keys_to_log
|
42 |
|
43 |
+
def on_done(self, ep_idx: int, episode: Episode, info: Dict) -> None:
|
44 |
+
if self.additional_keys_to_log:
|
45 |
+
episode.info = {k: info[k] for k in self.additional_keys_to_log}
|
46 |
if (
|
47 |
self.should_record_done(ep_idx)
|
48 |
and len(self.completed_episodes_by_env_idx[ep_idx])
|
|
|
79 |
deterministic: bool = True,
|
80 |
print_returns: bool = True,
|
81 |
ignore_first_episode: bool = False,
|
82 |
+
additional_keys_to_log: Optional[List[str]] = None,
|
83 |
) -> EpisodesStats:
|
84 |
policy.sync_normalization(env)
|
85 |
policy.eval()
|
86 |
|
87 |
episodes = EvaluateAccumulator(
|
88 |
+
env.num_envs,
|
89 |
+
n_episodes,
|
90 |
+
print_returns,
|
91 |
+
ignore_first_episode,
|
92 |
+
additional_keys_to_log=additional_keys_to_log,
|
93 |
)
|
94 |
|
95 |
obs = env.reset()
|
96 |
+
action_masker = find_action_masker(env)
|
97 |
while not episodes.is_done():
|
98 |
+
act = policy.act(
|
99 |
+
obs,
|
100 |
+
deterministic=deterministic,
|
101 |
+
action_masks=action_masker.action_masks() if action_masker else None,
|
102 |
+
)
|
103 |
+
obs, rew, done, info = env.step(act)
|
104 |
+
episodes.step(rew, done, info)
|
105 |
if render:
|
106 |
env.render()
|
107 |
stats = EpisodesStats(episodes.episodes)
|
|
|
126 |
best_video_dir: Optional[str] = None,
|
127 |
max_video_length: int = 3600,
|
128 |
ignore_first_episode: bool = False,
|
129 |
+
additional_keys_to_log: Optional[List[str]] = None,
|
130 |
) -> None:
|
131 |
super().__init__()
|
132 |
self.policy = policy
|
|
|
149 |
os.makedirs(best_video_dir, exist_ok=True)
|
150 |
self.max_video_length = max_video_length
|
151 |
self.best_video_base_path = None
|
|
|
152 |
self.ignore_first_episode = ignore_first_episode
|
153 |
+
self.additional_keys_to_log = additional_keys_to_log
|
154 |
|
155 |
def on_step(self, timesteps_elapsed: int = 1) -> bool:
|
156 |
super().on_step(timesteps_elapsed)
|
|
|
169 |
deterministic=self.deterministic,
|
170 |
print_returns=print_returns or False,
|
171 |
ignore_first_episode=self.ignore_first_episode,
|
172 |
+
additional_keys_to_log=self.additional_keys_to_log,
|
173 |
)
|
174 |
end_time = perf_counter()
|
175 |
self.tb_writer.add_scalar(
|
rl_algo_impls/shared/encoder/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from rl_algo_impls.shared.encoder.cnn import EncoderOutDim
|
2 |
+
from rl_algo_impls.shared.encoder.encoder import Encoder
|
rl_algo_impls/shared/encoder/cnn.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
from typing import Optional, Tuple, Type, Union
|
3 |
+
|
4 |
+
import gym
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
|
9 |
+
from rl_algo_impls.shared.module.module import layer_init
|
10 |
+
|
11 |
+
EncoderOutDim = Union[int, Tuple[int, ...]]
|
12 |
+
|
13 |
+
|
14 |
+
class CnnEncoder(nn.Module, ABC):
|
15 |
+
@abstractmethod
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
obs_space: gym.Space,
|
19 |
+
**kwargs,
|
20 |
+
) -> None:
|
21 |
+
super().__init__()
|
22 |
+
self.range_size = np.max(obs_space.high) - np.min(obs_space.low) # type: ignore
|
23 |
+
|
24 |
+
def preprocess(self, obs: torch.Tensor) -> torch.Tensor:
|
25 |
+
if len(obs.shape) == 3:
|
26 |
+
obs = obs.unsqueeze(0)
|
27 |
+
return obs.float() / self.range_size
|
28 |
+
|
29 |
+
def forward(self, obs: torch.Tensor) -> torch.Tensor:
|
30 |
+
return self.preprocess(obs)
|
31 |
+
|
32 |
+
@property
|
33 |
+
@abstractmethod
|
34 |
+
def out_dim(self) -> EncoderOutDim:
|
35 |
+
...
|
36 |
+
|
37 |
+
|
38 |
+
class FlattenedCnnEncoder(CnnEncoder):
|
39 |
+
def __init__(
|
40 |
+
self,
|
41 |
+
obs_space: gym.Space,
|
42 |
+
activation: Type[nn.Module],
|
43 |
+
linear_init_layers_orthogonal: bool,
|
44 |
+
cnn_flatten_dim: int,
|
45 |
+
cnn: nn.Module,
|
46 |
+
**kwargs,
|
47 |
+
) -> None:
|
48 |
+
super().__init__(obs_space, **kwargs)
|
49 |
+
self.cnn = cnn
|
50 |
+
self.flattened_dim = cnn_flatten_dim
|
51 |
+
with torch.no_grad():
|
52 |
+
cnn_out = torch.flatten(
|
53 |
+
cnn(self.preprocess(torch.as_tensor(obs_space.sample()))), start_dim=1
|
54 |
+
)
|
55 |
+
self.fc = nn.Sequential(
|
56 |
+
nn.Flatten(),
|
57 |
+
layer_init(
|
58 |
+
nn.Linear(cnn_out.shape[1], cnn_flatten_dim),
|
59 |
+
linear_init_layers_orthogonal,
|
60 |
+
),
|
61 |
+
activation(),
|
62 |
+
)
|
63 |
+
|
64 |
+
def forward(self, obs: torch.Tensor) -> torch.Tensor:
|
65 |
+
x = super().forward(obs)
|
66 |
+
x = self.cnn(x)
|
67 |
+
x = self.fc(x)
|
68 |
+
return x
|
69 |
+
|
70 |
+
@property
|
71 |
+
def out_dim(self) -> EncoderOutDim:
|
72 |
+
return self.flattened_dim
|
rl_algo_impls/shared/encoder/encoder.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, Optional, Sequence, Type
|
2 |
+
|
3 |
+
import gym
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from gym.spaces import Box, Discrete
|
8 |
+
from stable_baselines3.common.preprocessing import get_flattened_obs_dim
|
9 |
+
|
10 |
+
from rl_algo_impls.shared.encoder.cnn import CnnEncoder
|
11 |
+
from rl_algo_impls.shared.encoder.gridnet_encoder import GridnetEncoder
|
12 |
+
from rl_algo_impls.shared.encoder.impala_cnn import ImpalaCnn
|
13 |
+
from rl_algo_impls.shared.encoder.microrts_cnn import MicrortsCnn
|
14 |
+
from rl_algo_impls.shared.encoder.nature_cnn import NatureCnn
|
15 |
+
from rl_algo_impls.shared.module.module import layer_init
|
16 |
+
|
17 |
+
CNN_EXTRACTORS_BY_STYLE: Dict[str, Type[CnnEncoder]] = {
|
18 |
+
"nature": NatureCnn,
|
19 |
+
"impala": ImpalaCnn,
|
20 |
+
"microrts": MicrortsCnn,
|
21 |
+
"gridnet_encoder": GridnetEncoder,
|
22 |
+
}
|
23 |
+
|
24 |
+
|
25 |
+
class Encoder(nn.Module):
|
26 |
+
def __init__(
|
27 |
+
self,
|
28 |
+
obs_space: gym.Space,
|
29 |
+
activation: Type[nn.Module],
|
30 |
+
init_layers_orthogonal: bool = False,
|
31 |
+
cnn_flatten_dim: int = 512,
|
32 |
+
cnn_style: str = "nature",
|
33 |
+
cnn_layers_init_orthogonal: Optional[bool] = None,
|
34 |
+
impala_channels: Sequence[int] = (16, 32, 32),
|
35 |
+
) -> None:
|
36 |
+
super().__init__()
|
37 |
+
if isinstance(obs_space, Box):
|
38 |
+
# Conv2D: (channels, height, width)
|
39 |
+
if len(obs_space.shape) == 3: # type: ignore
|
40 |
+
self.preprocess = None
|
41 |
+
cnn = CNN_EXTRACTORS_BY_STYLE[cnn_style](
|
42 |
+
obs_space,
|
43 |
+
activation=activation,
|
44 |
+
cnn_init_layers_orthogonal=cnn_layers_init_orthogonal,
|
45 |
+
linear_init_layers_orthogonal=init_layers_orthogonal,
|
46 |
+
cnn_flatten_dim=cnn_flatten_dim,
|
47 |
+
impala_channels=impala_channels,
|
48 |
+
)
|
49 |
+
self.feature_extractor = cnn
|
50 |
+
self.out_dim = cnn.out_dim
|
51 |
+
elif len(obs_space.shape) == 1: # type: ignore
|
52 |
+
|
53 |
+
def preprocess(obs: torch.Tensor) -> torch.Tensor:
|
54 |
+
if len(obs.shape) == 1:
|
55 |
+
obs = obs.unsqueeze(0)
|
56 |
+
return obs.float()
|
57 |
+
|
58 |
+
self.preprocess = preprocess
|
59 |
+
self.feature_extractor = nn.Flatten()
|
60 |
+
self.out_dim = get_flattened_obs_dim(obs_space)
|
61 |
+
else:
|
62 |
+
raise ValueError(f"Unsupported observation space: {obs_space}")
|
63 |
+
elif isinstance(obs_space, Discrete):
|
64 |
+
self.preprocess = lambda x: F.one_hot(x, obs_space.n).float()
|
65 |
+
self.feature_extractor = nn.Flatten()
|
66 |
+
self.out_dim = obs_space.n # type: ignore
|
67 |
+
else:
|
68 |
+
raise NotImplementedError
|
69 |
+
|
70 |
+
def forward(self, obs: torch.Tensor) -> torch.Tensor:
|
71 |
+
if self.preprocess:
|
72 |
+
obs = self.preprocess(obs)
|
73 |
+
return self.feature_extractor(obs)
|
rl_algo_impls/shared/encoder/gridnet_encoder.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Tuple, Type, Union
|
2 |
+
|
3 |
+
import gym
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
from rl_algo_impls.shared.encoder.cnn import CnnEncoder, EncoderOutDim
|
8 |
+
from rl_algo_impls.shared.module.module import layer_init
|
9 |
+
|
10 |
+
|
11 |
+
class GridnetEncoder(CnnEncoder):
|
12 |
+
"""
|
13 |
+
Encoder for encoder-decoder for Gym-MicroRTS
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
obs_space: gym.Space,
|
19 |
+
activation: Type[nn.Module] = nn.ReLU,
|
20 |
+
cnn_init_layers_orthogonal: Optional[bool] = None,
|
21 |
+
**kwargs
|
22 |
+
) -> None:
|
23 |
+
if cnn_init_layers_orthogonal is None:
|
24 |
+
cnn_init_layers_orthogonal = True
|
25 |
+
super().__init__(obs_space, **kwargs)
|
26 |
+
in_channels = obs_space.shape[0] # type: ignore
|
27 |
+
self.encoder = nn.Sequential(
|
28 |
+
layer_init(
|
29 |
+
nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
|
30 |
+
cnn_init_layers_orthogonal,
|
31 |
+
),
|
32 |
+
nn.MaxPool2d(3, stride=2, padding=1),
|
33 |
+
activation(),
|
34 |
+
layer_init(
|
35 |
+
nn.Conv2d(32, 64, kernel_size=3, padding=1),
|
36 |
+
cnn_init_layers_orthogonal,
|
37 |
+
),
|
38 |
+
nn.MaxPool2d(3, stride=2, padding=1),
|
39 |
+
activation(),
|
40 |
+
layer_init(
|
41 |
+
nn.Conv2d(64, 128, kernel_size=3, padding=1),
|
42 |
+
cnn_init_layers_orthogonal,
|
43 |
+
),
|
44 |
+
nn.MaxPool2d(3, stride=2, padding=1),
|
45 |
+
activation(),
|
46 |
+
layer_init(
|
47 |
+
nn.Conv2d(128, 256, kernel_size=3, padding=1),
|
48 |
+
cnn_init_layers_orthogonal,
|
49 |
+
),
|
50 |
+
nn.MaxPool2d(3, stride=2, padding=1),
|
51 |
+
activation(),
|
52 |
+
)
|
53 |
+
with torch.no_grad():
|
54 |
+
encoder_out = self.encoder(
|
55 |
+
self.preprocess(torch.as_tensor(obs_space.sample())) # type: ignore
|
56 |
+
)
|
57 |
+
self._out_dim = encoder_out.shape[1:]
|
58 |
+
|
59 |
+
def forward(self, obs: torch.Tensor) -> torch.Tensor:
|
60 |
+
return self.encoder(super().forward(obs))
|
61 |
+
|
62 |
+
@property
|
63 |
+
def out_dim(self) -> EncoderOutDim:
|
64 |
+
return self._out_dim
|
rl_algo_impls/shared/encoder/impala_cnn.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Sequence, Type
|
2 |
+
|
3 |
+
import gym
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
from rl_algo_impls.shared.encoder.cnn import FlattenedCnnEncoder
|
8 |
+
from rl_algo_impls.shared.module.module import layer_init
|
9 |
+
|
10 |
+
|
11 |
+
class ResidualBlock(nn.Module):
|
12 |
+
def __init__(
|
13 |
+
self,
|
14 |
+
channels: int,
|
15 |
+
activation: Type[nn.Module] = nn.ReLU,
|
16 |
+
init_layers_orthogonal: bool = False,
|
17 |
+
) -> None:
|
18 |
+
super().__init__()
|
19 |
+
self.residual = nn.Sequential(
|
20 |
+
activation(),
|
21 |
+
layer_init(
|
22 |
+
nn.Conv2d(channels, channels, 3, padding=1), init_layers_orthogonal
|
23 |
+
),
|
24 |
+
activation(),
|
25 |
+
layer_init(
|
26 |
+
nn.Conv2d(channels, channels, 3, padding=1), init_layers_orthogonal
|
27 |
+
),
|
28 |
+
)
|
29 |
+
|
30 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
31 |
+
return x + self.residual(x)
|
32 |
+
|
33 |
+
|
34 |
+
class ConvSequence(nn.Module):
|
35 |
+
def __init__(
|
36 |
+
self,
|
37 |
+
in_channels: int,
|
38 |
+
out_channels: int,
|
39 |
+
activation: Type[nn.Module] = nn.ReLU,
|
40 |
+
init_layers_orthogonal: bool = False,
|
41 |
+
) -> None:
|
42 |
+
super().__init__()
|
43 |
+
self.seq = nn.Sequential(
|
44 |
+
layer_init(
|
45 |
+
nn.Conv2d(in_channels, out_channels, 3, padding=1),
|
46 |
+
init_layers_orthogonal,
|
47 |
+
),
|
48 |
+
nn.MaxPool2d(3, stride=2, padding=1),
|
49 |
+
ResidualBlock(out_channels, activation, init_layers_orthogonal),
|
50 |
+
ResidualBlock(out_channels, activation, init_layers_orthogonal),
|
51 |
+
)
|
52 |
+
|
53 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
54 |
+
return self.seq(x)
|
55 |
+
|
56 |
+
|
57 |
+
class ImpalaCnn(FlattenedCnnEncoder):
|
58 |
+
"""
|
59 |
+
IMPALA-style CNN architecture
|
60 |
+
"""
|
61 |
+
|
62 |
+
def __init__(
|
63 |
+
self,
|
64 |
+
obs_space: gym.Space,
|
65 |
+
activation: Type[nn.Module],
|
66 |
+
cnn_init_layers_orthogonal: Optional[bool],
|
67 |
+
linear_init_layers_orthogonal: bool,
|
68 |
+
cnn_flatten_dim: int,
|
69 |
+
impala_channels: Sequence[int] = (16, 32, 32),
|
70 |
+
**kwargs,
|
71 |
+
) -> None:
|
72 |
+
if cnn_init_layers_orthogonal is None:
|
73 |
+
cnn_init_layers_orthogonal = False
|
74 |
+
in_channels = obs_space.shape[0] # type: ignore
|
75 |
+
sequences = []
|
76 |
+
for out_channels in impala_channels:
|
77 |
+
sequences.append(
|
78 |
+
ConvSequence(
|
79 |
+
in_channels, out_channels, activation, cnn_init_layers_orthogonal
|
80 |
+
)
|
81 |
+
)
|
82 |
+
in_channels = out_channels
|
83 |
+
sequences.append(activation())
|
84 |
+
cnn = nn.Sequential(*sequences)
|
85 |
+
super().__init__(
|
86 |
+
obs_space,
|
87 |
+
activation,
|
88 |
+
linear_init_layers_orthogonal,
|
89 |
+
cnn_flatten_dim,
|
90 |
+
cnn,
|
91 |
+
**kwargs,
|
92 |
+
)
|
rl_algo_impls/shared/encoder/microrts_cnn.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Type
|
2 |
+
|
3 |
+
import gym
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
from rl_algo_impls.shared.encoder.cnn import FlattenedCnnEncoder
|
8 |
+
from rl_algo_impls.shared.module.module import layer_init
|
9 |
+
|
10 |
+
|
11 |
+
class MicrortsCnn(FlattenedCnnEncoder):
|
12 |
+
"""
|
13 |
+
Base CNN architecture for Gym-MicroRTS
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
obs_space: gym.Space,
|
19 |
+
activation: Type[nn.Module],
|
20 |
+
cnn_init_layers_orthogonal: Optional[bool],
|
21 |
+
linear_init_layers_orthogonal: bool,
|
22 |
+
cnn_flatten_dim: int,
|
23 |
+
**kwargs,
|
24 |
+
) -> None:
|
25 |
+
if cnn_init_layers_orthogonal is None:
|
26 |
+
cnn_init_layers_orthogonal = True
|
27 |
+
in_channels = obs_space.shape[0] # type: ignore
|
28 |
+
cnn = nn.Sequential(
|
29 |
+
layer_init(
|
30 |
+
nn.Conv2d(in_channels, 16, kernel_size=3, stride=2),
|
31 |
+
cnn_init_layers_orthogonal,
|
32 |
+
),
|
33 |
+
activation(),
|
34 |
+
layer_init(nn.Conv2d(16, 32, kernel_size=2), cnn_init_layers_orthogonal),
|
35 |
+
activation(),
|
36 |
+
nn.Flatten(),
|
37 |
+
)
|
38 |
+
super().__init__(
|
39 |
+
obs_space,
|
40 |
+
activation,
|
41 |
+
linear_init_layers_orthogonal,
|
42 |
+
cnn_flatten_dim,
|
43 |
+
cnn,
|
44 |
+
**kwargs,
|
45 |
+
)
|
rl_algo_impls/shared/encoder/nature_cnn.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Type
|
2 |
+
|
3 |
+
import gym
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
from rl_algo_impls.shared.encoder.cnn import FlattenedCnnEncoder
|
7 |
+
from rl_algo_impls.shared.module.module import layer_init
|
8 |
+
|
9 |
+
|
10 |
+
class NatureCnn(FlattenedCnnEncoder):
|
11 |
+
"""
|
12 |
+
CNN from DQN Nature paper: Mnih, Volodymyr, et al.
|
13 |
+
"Human-level control through deep reinforcement learning."
|
14 |
+
Nature 518.7540 (2015): 529-533.
|
15 |
+
"""
|
16 |
+
|
17 |
+
def __init__(
|
18 |
+
self,
|
19 |
+
obs_space: gym.Space,
|
20 |
+
activation: Type[nn.Module],
|
21 |
+
cnn_init_layers_orthogonal: Optional[bool],
|
22 |
+
linear_init_layers_orthogonal: bool,
|
23 |
+
cnn_flatten_dim: int,
|
24 |
+
**kwargs,
|
25 |
+
) -> None:
|
26 |
+
if cnn_init_layers_orthogonal is None:
|
27 |
+
cnn_init_layers_orthogonal = True
|
28 |
+
in_channels = obs_space.shape[0] # type: ignore
|
29 |
+
cnn = nn.Sequential(
|
30 |
+
layer_init(
|
31 |
+
nn.Conv2d(in_channels, 32, kernel_size=8, stride=4),
|
32 |
+
cnn_init_layers_orthogonal,
|
33 |
+
),
|
34 |
+
activation(),
|
35 |
+
layer_init(
|
36 |
+
nn.Conv2d(32, 64, kernel_size=4, stride=2),
|
37 |
+
cnn_init_layers_orthogonal,
|
38 |
+
),
|
39 |
+
activation(),
|
40 |
+
layer_init(
|
41 |
+
nn.Conv2d(64, 64, kernel_size=3, stride=1),
|
42 |
+
cnn_init_layers_orthogonal,
|
43 |
+
),
|
44 |
+
activation(),
|
45 |
+
)
|
46 |
+
super().__init__(
|
47 |
+
obs_space,
|
48 |
+
activation,
|
49 |
+
linear_init_layers_orthogonal,
|
50 |
+
cnn_flatten_dim,
|
51 |
+
cnn,
|
52 |
+
**kwargs,
|
53 |
+
)
|
rl_algo_impls/shared/gae.py
CHANGED
@@ -5,6 +5,7 @@ from typing import NamedTuple, Sequence
|
|
5 |
|
6 |
from rl_algo_impls.shared.policy.on_policy import OnPolicy
|
7 |
from rl_algo_impls.shared.trajectory import Trajectory
|
|
|
8 |
|
9 |
|
10 |
class RtgAdvantage(NamedTuple):
|
@@ -19,7 +20,7 @@ def discounted_cumsum(x: np.ndarray, gamma: float) -> np.ndarray:
|
|
19 |
return dc
|
20 |
|
21 |
|
22 |
-
def
|
23 |
trajectories: Sequence[Trajectory],
|
24 |
policy: OnPolicy,
|
25 |
gamma: float,
|
@@ -40,7 +41,7 @@ def compute_advantage(
|
|
40 |
)
|
41 |
|
42 |
|
43 |
-
def
|
44 |
trajectories: Sequence[Trajectory],
|
45 |
policy: OnPolicy,
|
46 |
gamma: float,
|
@@ -65,3 +66,29 @@ def compute_rtg_and_advantage(
|
|
65 |
),
|
66 |
torch.as_tensor(np.concatenate(advantages), dtype=torch.float32, device=device),
|
67 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
from rl_algo_impls.shared.policy.on_policy import OnPolicy
|
7 |
from rl_algo_impls.shared.trajectory import Trajectory
|
8 |
+
from rl_algo_impls.wrappers.vectorable_wrapper import VecEnvObs
|
9 |
|
10 |
|
11 |
class RtgAdvantage(NamedTuple):
|
|
|
20 |
return dc
|
21 |
|
22 |
|
23 |
+
def compute_advantage_from_trajectories(
|
24 |
trajectories: Sequence[Trajectory],
|
25 |
policy: OnPolicy,
|
26 |
gamma: float,
|
|
|
41 |
)
|
42 |
|
43 |
|
44 |
+
def compute_rtg_and_advantage_from_trajectories(
|
45 |
trajectories: Sequence[Trajectory],
|
46 |
policy: OnPolicy,
|
47 |
gamma: float,
|
|
|
66 |
),
|
67 |
torch.as_tensor(np.concatenate(advantages), dtype=torch.float32, device=device),
|
68 |
)
|
69 |
+
|
70 |
+
|
71 |
+
def compute_advantages(
|
72 |
+
rewards: np.ndarray,
|
73 |
+
values: np.ndarray,
|
74 |
+
episode_starts: np.ndarray,
|
75 |
+
next_episode_starts: np.ndarray,
|
76 |
+
next_obs: VecEnvObs,
|
77 |
+
policy: OnPolicy,
|
78 |
+
gamma: float,
|
79 |
+
gae_lambda: float,
|
80 |
+
) -> np.ndarray:
|
81 |
+
advantages = np.zeros_like(rewards)
|
82 |
+
last_gae_lam = 0
|
83 |
+
n_steps = advantages.shape[0]
|
84 |
+
for t in reversed(range(n_steps)):
|
85 |
+
if t == n_steps - 1:
|
86 |
+
next_nonterminal = 1.0 - next_episode_starts
|
87 |
+
next_value = policy.value(next_obs)
|
88 |
+
else:
|
89 |
+
next_nonterminal = 1.0 - episode_starts[t + 1]
|
90 |
+
next_value = values[t + 1]
|
91 |
+
delta = rewards[t] + gamma * next_value * next_nonterminal - values[t]
|
92 |
+
last_gae_lam = delta + gamma * gae_lambda * next_nonterminal * last_gae_lam
|
93 |
+
advantages[t] = last_gae_lam
|
94 |
+
return advantages
|
rl_algo_impls/shared/module/feature_extractor.py
DELETED
@@ -1,215 +0,0 @@
|
|
1 |
-
import gym
|
2 |
-
import torch
|
3 |
-
import torch.nn as nn
|
4 |
-
import torch.nn.functional as F
|
5 |
-
|
6 |
-
from abc import ABC, abstractmethod
|
7 |
-
from gym.spaces import Box, Discrete
|
8 |
-
from stable_baselines3.common.preprocessing import get_flattened_obs_dim
|
9 |
-
from typing import Dict, Optional, Sequence, Type
|
10 |
-
|
11 |
-
from rl_algo_impls.shared.module.module import layer_init
|
12 |
-
|
13 |
-
|
14 |
-
class CnnFeatureExtractor(nn.Module, ABC):
|
15 |
-
@abstractmethod
|
16 |
-
def __init__(
|
17 |
-
self,
|
18 |
-
in_channels: int,
|
19 |
-
activation: Type[nn.Module] = nn.ReLU,
|
20 |
-
init_layers_orthogonal: Optional[bool] = None,
|
21 |
-
**kwargs,
|
22 |
-
) -> None:
|
23 |
-
super().__init__()
|
24 |
-
|
25 |
-
|
26 |
-
class NatureCnn(CnnFeatureExtractor):
|
27 |
-
"""
|
28 |
-
CNN from DQN Nature paper: Mnih, Volodymyr, et al.
|
29 |
-
"Human-level control through deep reinforcement learning."
|
30 |
-
Nature 518.7540 (2015): 529-533.
|
31 |
-
"""
|
32 |
-
|
33 |
-
def __init__(
|
34 |
-
self,
|
35 |
-
in_channels: int,
|
36 |
-
activation: Type[nn.Module] = nn.ReLU,
|
37 |
-
init_layers_orthogonal: Optional[bool] = None,
|
38 |
-
**kwargs,
|
39 |
-
) -> None:
|
40 |
-
if init_layers_orthogonal is None:
|
41 |
-
init_layers_orthogonal = True
|
42 |
-
super().__init__(in_channels, activation, init_layers_orthogonal)
|
43 |
-
self.cnn = nn.Sequential(
|
44 |
-
layer_init(
|
45 |
-
nn.Conv2d(in_channels, 32, kernel_size=8, stride=4),
|
46 |
-
init_layers_orthogonal,
|
47 |
-
),
|
48 |
-
activation(),
|
49 |
-
layer_init(
|
50 |
-
nn.Conv2d(32, 64, kernel_size=4, stride=2),
|
51 |
-
init_layers_orthogonal,
|
52 |
-
),
|
53 |
-
activation(),
|
54 |
-
layer_init(
|
55 |
-
nn.Conv2d(64, 64, kernel_size=3, stride=1),
|
56 |
-
init_layers_orthogonal,
|
57 |
-
),
|
58 |
-
activation(),
|
59 |
-
nn.Flatten(),
|
60 |
-
)
|
61 |
-
|
62 |
-
def forward(self, obs: torch.Tensor) -> torch.Tensor:
|
63 |
-
return self.cnn(obs)
|
64 |
-
|
65 |
-
|
66 |
-
class ResidualBlock(nn.Module):
|
67 |
-
def __init__(
|
68 |
-
self,
|
69 |
-
channels: int,
|
70 |
-
activation: Type[nn.Module] = nn.ReLU,
|
71 |
-
init_layers_orthogonal: bool = False,
|
72 |
-
) -> None:
|
73 |
-
super().__init__()
|
74 |
-
self.residual = nn.Sequential(
|
75 |
-
activation(),
|
76 |
-
layer_init(
|
77 |
-
nn.Conv2d(channels, channels, 3, padding=1), init_layers_orthogonal
|
78 |
-
),
|
79 |
-
activation(),
|
80 |
-
layer_init(
|
81 |
-
nn.Conv2d(channels, channels, 3, padding=1), init_layers_orthogonal
|
82 |
-
),
|
83 |
-
)
|
84 |
-
|
85 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
86 |
-
return x + self.residual(x)
|
87 |
-
|
88 |
-
|
89 |
-
class ConvSequence(nn.Module):
|
90 |
-
def __init__(
|
91 |
-
self,
|
92 |
-
in_channels: int,
|
93 |
-
out_channels: int,
|
94 |
-
activation: Type[nn.Module] = nn.ReLU,
|
95 |
-
init_layers_orthogonal: bool = False,
|
96 |
-
) -> None:
|
97 |
-
super().__init__()
|
98 |
-
self.seq = nn.Sequential(
|
99 |
-
layer_init(
|
100 |
-
nn.Conv2d(in_channels, out_channels, 3, padding=1),
|
101 |
-
init_layers_orthogonal,
|
102 |
-
),
|
103 |
-
nn.MaxPool2d(3, stride=2, padding=1),
|
104 |
-
ResidualBlock(out_channels, activation, init_layers_orthogonal),
|
105 |
-
ResidualBlock(out_channels, activation, init_layers_orthogonal),
|
106 |
-
)
|
107 |
-
|
108 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
109 |
-
return self.seq(x)
|
110 |
-
|
111 |
-
|
112 |
-
class ImpalaCnn(CnnFeatureExtractor):
|
113 |
-
"""
|
114 |
-
IMPALA-style CNN architecture
|
115 |
-
"""
|
116 |
-
|
117 |
-
def __init__(
|
118 |
-
self,
|
119 |
-
in_channels: int,
|
120 |
-
activation: Type[nn.Module] = nn.ReLU,
|
121 |
-
init_layers_orthogonal: Optional[bool] = None,
|
122 |
-
impala_channels: Sequence[int] = (16, 32, 32),
|
123 |
-
**kwargs,
|
124 |
-
) -> None:
|
125 |
-
if init_layers_orthogonal is None:
|
126 |
-
init_layers_orthogonal = False
|
127 |
-
super().__init__(in_channels, activation, init_layers_orthogonal)
|
128 |
-
sequences = []
|
129 |
-
for out_channels in impala_channels:
|
130 |
-
sequences.append(
|
131 |
-
ConvSequence(
|
132 |
-
in_channels, out_channels, activation, init_layers_orthogonal
|
133 |
-
)
|
134 |
-
)
|
135 |
-
in_channels = out_channels
|
136 |
-
sequences.extend(
|
137 |
-
[
|
138 |
-
activation(),
|
139 |
-
nn.Flatten(),
|
140 |
-
]
|
141 |
-
)
|
142 |
-
self.seq = nn.Sequential(*sequences)
|
143 |
-
|
144 |
-
def forward(self, obs: torch.Tensor) -> torch.Tensor:
|
145 |
-
return self.seq(obs)
|
146 |
-
|
147 |
-
|
148 |
-
CNN_EXTRACTORS_BY_STYLE: Dict[str, Type[CnnFeatureExtractor]] = {
|
149 |
-
"nature": NatureCnn,
|
150 |
-
"impala": ImpalaCnn,
|
151 |
-
}
|
152 |
-
|
153 |
-
|
154 |
-
class FeatureExtractor(nn.Module):
|
155 |
-
def __init__(
|
156 |
-
self,
|
157 |
-
obs_space: gym.Space,
|
158 |
-
activation: Type[nn.Module],
|
159 |
-
init_layers_orthogonal: bool = False,
|
160 |
-
cnn_feature_dim: int = 512,
|
161 |
-
cnn_style: str = "nature",
|
162 |
-
cnn_layers_init_orthogonal: Optional[bool] = None,
|
163 |
-
impala_channels: Sequence[int] = (16, 32, 32),
|
164 |
-
) -> None:
|
165 |
-
super().__init__()
|
166 |
-
if isinstance(obs_space, Box):
|
167 |
-
# Conv2D: (channels, height, width)
|
168 |
-
if len(obs_space.shape) == 3:
|
169 |
-
cnn = CNN_EXTRACTORS_BY_STYLE[cnn_style](
|
170 |
-
obs_space.shape[0],
|
171 |
-
activation,
|
172 |
-
init_layers_orthogonal=cnn_layers_init_orthogonal,
|
173 |
-
impala_channels=impala_channels,
|
174 |
-
)
|
175 |
-
|
176 |
-
def preprocess(obs: torch.Tensor) -> torch.Tensor:
|
177 |
-
if len(obs.shape) == 3:
|
178 |
-
obs = obs.unsqueeze(0)
|
179 |
-
return obs.float() / 255.0
|
180 |
-
|
181 |
-
with torch.no_grad():
|
182 |
-
cnn_out = cnn(preprocess(torch.as_tensor(obs_space.sample())))
|
183 |
-
self.preprocess = preprocess
|
184 |
-
self.feature_extractor = nn.Sequential(
|
185 |
-
cnn,
|
186 |
-
layer_init(
|
187 |
-
nn.Linear(cnn_out.shape[1], cnn_feature_dim),
|
188 |
-
init_layers_orthogonal,
|
189 |
-
),
|
190 |
-
activation(),
|
191 |
-
)
|
192 |
-
self.out_dim = cnn_feature_dim
|
193 |
-
elif len(obs_space.shape) == 1:
|
194 |
-
|
195 |
-
def preprocess(obs: torch.Tensor) -> torch.Tensor:
|
196 |
-
if len(obs.shape) == 1:
|
197 |
-
obs = obs.unsqueeze(0)
|
198 |
-
return obs.float()
|
199 |
-
|
200 |
-
self.preprocess = preprocess
|
201 |
-
self.feature_extractor = nn.Flatten()
|
202 |
-
self.out_dim = get_flattened_obs_dim(obs_space)
|
203 |
-
else:
|
204 |
-
raise ValueError(f"Unsupported observation space: {obs_space}")
|
205 |
-
elif isinstance(obs_space, Discrete):
|
206 |
-
self.preprocess = lambda x: F.one_hot(x, obs_space.n).float()
|
207 |
-
self.feature_extractor = nn.Flatten()
|
208 |
-
self.out_dim = obs_space.n
|
209 |
-
else:
|
210 |
-
raise NotImplementedError
|
211 |
-
|
212 |
-
def forward(self, obs: torch.Tensor) -> torch.Tensor:
|
213 |
-
if self.preprocess:
|
214 |
-
obs = self.preprocess(obs)
|
215 |
-
return self.feature_extractor(obs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
rl_algo_impls/shared/module/module.py
CHANGED
@@ -1,8 +1,8 @@
|
|
|
|
|
|
1 |
import numpy as np
|
2 |
import torch.nn as nn
|
3 |
|
4 |
-
from typing import Sequence, Type
|
5 |
-
|
6 |
|
7 |
def mlp(
|
8 |
layer_sizes: Sequence[int],
|
@@ -10,12 +10,15 @@ def mlp(
|
|
10 |
output_activation: Type[nn.Module] = nn.Identity,
|
11 |
init_layers_orthogonal: bool = False,
|
12 |
final_layer_gain: float = np.sqrt(2),
|
|
|
13 |
) -> nn.Module:
|
14 |
layers = []
|
15 |
for i in range(len(layer_sizes) - 2):
|
16 |
layers.append(
|
17 |
layer_init(
|
18 |
-
nn.Linear(layer_sizes[i], layer_sizes[i + 1]),
|
|
|
|
|
19 |
)
|
20 |
)
|
21 |
layers.append(activation())
|
|
|
1 |
+
from typing import Sequence, Type
|
2 |
+
|
3 |
import numpy as np
|
4 |
import torch.nn as nn
|
5 |
|
|
|
|
|
6 |
|
7 |
def mlp(
|
8 |
layer_sizes: Sequence[int],
|
|
|
10 |
output_activation: Type[nn.Module] = nn.Identity,
|
11 |
init_layers_orthogonal: bool = False,
|
12 |
final_layer_gain: float = np.sqrt(2),
|
13 |
+
hidden_layer_gain: float = np.sqrt(2),
|
14 |
) -> nn.Module:
|
15 |
layers = []
|
16 |
for i in range(len(layer_sizes) - 2):
|
17 |
layers.append(
|
18 |
layer_init(
|
19 |
+
nn.Linear(layer_sizes[i], layer_sizes[i + 1]),
|
20 |
+
init_layers_orthogonal,
|
21 |
+
std=hidden_layer_gain,
|
22 |
)
|
23 |
)
|
24 |
layers.append(activation())
|
rl_algo_impls/shared/policy/critic.py
CHANGED
@@ -1,27 +1,39 @@
|
|
1 |
-
import
|
|
|
|
|
2 |
import torch
|
3 |
import torch.nn as nn
|
4 |
|
5 |
-
from
|
6 |
-
|
7 |
from rl_algo_impls.shared.module.module import mlp
|
8 |
|
9 |
|
10 |
class CriticHead(nn.Module):
|
11 |
def __init__(
|
12 |
self,
|
13 |
-
|
|
|
14 |
activation: Type[nn.Module] = nn.Tanh,
|
15 |
init_layers_orthogonal: bool = True,
|
16 |
) -> None:
|
17 |
super().__init__()
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
)
|
|
|
25 |
|
26 |
def forward(self, obs: torch.Tensor) -> torch.Tensor:
|
27 |
v = self._fc(obs)
|
|
|
1 |
+
from typing import Sequence, Type
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
import torch
|
5 |
import torch.nn as nn
|
6 |
|
7 |
+
from rl_algo_impls.shared.encoder import EncoderOutDim
|
|
|
8 |
from rl_algo_impls.shared.module.module import mlp
|
9 |
|
10 |
|
11 |
class CriticHead(nn.Module):
|
12 |
def __init__(
|
13 |
self,
|
14 |
+
in_dim: EncoderOutDim,
|
15 |
+
hidden_sizes: Sequence[int] = (),
|
16 |
activation: Type[nn.Module] = nn.Tanh,
|
17 |
init_layers_orthogonal: bool = True,
|
18 |
) -> None:
|
19 |
super().__init__()
|
20 |
+
seq = []
|
21 |
+
if isinstance(in_dim, tuple):
|
22 |
+
seq.append(nn.Flatten())
|
23 |
+
in_channels = int(np.prod(in_dim))
|
24 |
+
else:
|
25 |
+
in_channels = in_dim
|
26 |
+
layer_sizes = (in_channels,) + tuple(hidden_sizes) + (1,)
|
27 |
+
seq.append(
|
28 |
+
mlp(
|
29 |
+
layer_sizes,
|
30 |
+
activation,
|
31 |
+
init_layers_orthogonal=init_layers_orthogonal,
|
32 |
+
final_layer_gain=1.0,
|
33 |
+
hidden_layer_gain=1.0,
|
34 |
+
)
|
35 |
)
|
36 |
+
self._fc = nn.Sequential(*seq)
|
37 |
|
38 |
def forward(self, obs: torch.Tensor) -> torch.Tensor:
|
39 |
v = self._fc(obs)
|
rl_algo_impls/shared/policy/on_policy.py
CHANGED
@@ -1,24 +1,20 @@
|
|
|
|
|
|
|
|
1 |
import gym
|
2 |
import numpy as np
|
3 |
import torch
|
4 |
-
|
5 |
-
from abc import abstractmethod
|
6 |
from gym.spaces import Box, Discrete, Space
|
7 |
-
from typing import NamedTuple, Optional, Sequence, Tuple, TypeVar
|
8 |
|
9 |
-
from rl_algo_impls.shared.
|
10 |
-
from rl_algo_impls.shared.
|
11 |
-
PiForward,
|
12 |
-
StateDependentNoiseActorHead,
|
13 |
-
actor_head,
|
14 |
-
)
|
15 |
from rl_algo_impls.shared.policy.critic import CriticHead
|
16 |
from rl_algo_impls.shared.policy.policy import ACTIVATION, Policy
|
17 |
from rl_algo_impls.wrappers.vectorable_wrapper import (
|
18 |
VecEnv,
|
19 |
VecEnvObs,
|
20 |
-
single_observation_space,
|
21 |
single_action_space,
|
|
|
22 |
)
|
23 |
|
24 |
|
@@ -77,7 +73,12 @@ class OnPolicy(Policy):
|
|
77 |
...
|
78 |
|
79 |
@abstractmethod
|
80 |
-
def step(self, obs: VecEnvObs) -> Step:
|
|
|
|
|
|
|
|
|
|
|
81 |
...
|
82 |
|
83 |
|
@@ -94,10 +95,11 @@ class ActorCritic(OnPolicy):
|
|
94 |
full_std: bool = True,
|
95 |
squash_output: bool = False,
|
96 |
share_features_extractor: bool = True,
|
97 |
-
|
98 |
cnn_style: str = "nature",
|
99 |
cnn_layers_init_orthogonal: Optional[bool] = None,
|
100 |
impala_channels: Sequence[int] = (16, 32, 32),
|
|
|
101 |
**kwargs,
|
102 |
) -> None:
|
103 |
super().__init__(env, **kwargs)
|
@@ -120,52 +122,56 @@ class ActorCritic(OnPolicy):
|
|
120 |
self.action_space = action_space
|
121 |
self.squash_output = squash_output
|
122 |
self.share_features_extractor = share_features_extractor
|
123 |
-
self._feature_extractor =
|
124 |
observation_space,
|
125 |
activation,
|
126 |
init_layers_orthogonal=init_layers_orthogonal,
|
127 |
-
|
128 |
cnn_style=cnn_style,
|
129 |
cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
|
130 |
impala_channels=impala_channels,
|
131 |
)
|
132 |
self._pi = actor_head(
|
133 |
self.action_space,
|
134 |
-
|
|
|
135 |
init_layers_orthogonal,
|
136 |
activation,
|
137 |
log_std_init=log_std_init,
|
138 |
use_sde=use_sde,
|
139 |
full_std=full_std,
|
140 |
squash_output=squash_output,
|
|
|
141 |
)
|
142 |
|
143 |
if not share_features_extractor:
|
144 |
-
self._v_feature_extractor =
|
145 |
observation_space,
|
146 |
activation,
|
147 |
init_layers_orthogonal=init_layers_orthogonal,
|
148 |
-
|
149 |
cnn_style=cnn_style,
|
150 |
cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
|
151 |
)
|
152 |
-
|
153 |
-
v_hidden_sizes
|
154 |
-
)
|
155 |
else:
|
156 |
self._v_feature_extractor = None
|
157 |
-
|
158 |
self._v = CriticHead(
|
|
|
159 |
hidden_sizes=v_hidden_sizes,
|
160 |
activation=activation,
|
161 |
init_layers_orthogonal=init_layers_orthogonal,
|
162 |
)
|
163 |
|
164 |
def _pi_forward(
|
165 |
-
self,
|
|
|
|
|
|
|
166 |
) -> Tuple[PiForward, torch.Tensor]:
|
167 |
p_fe = self._feature_extractor(obs)
|
168 |
-
pi_forward = self._pi(p_fe, action)
|
169 |
|
170 |
return pi_forward, p_fe
|
171 |
|
@@ -173,8 +179,13 @@ class ActorCritic(OnPolicy):
|
|
173 |
v_fe = self._v_feature_extractor(obs) if self._v_feature_extractor else p_fc
|
174 |
return self._v(v_fe)
|
175 |
|
176 |
-
def forward(
|
177 |
-
|
|
|
|
|
|
|
|
|
|
|
178 |
v = self._v_forward(obs, p_fc)
|
179 |
|
180 |
assert logp_a is not None
|
@@ -192,10 +203,11 @@ class ActorCritic(OnPolicy):
|
|
192 |
v = self._v(fe)
|
193 |
return v.cpu().numpy()
|
194 |
|
195 |
-
def step(self, obs: VecEnvObs) -> Step:
|
196 |
o = self._as_tensor(obs)
|
|
|
197 |
with torch.no_grad():
|
198 |
-
(pi, _, _), p_fc = self._pi_forward(o)
|
199 |
a = pi.sample()
|
200 |
logp_a = pi.log_prob(a)
|
201 |
|
@@ -205,13 +217,21 @@ class ActorCritic(OnPolicy):
|
|
205 |
clamped_a_np = clamp_actions(a_np, self.action_space, self.squash_output)
|
206 |
return Step(a_np, v.cpu().numpy(), logp_a.cpu().numpy(), clamped_a_np)
|
207 |
|
208 |
-
def act(
|
|
|
|
|
|
|
|
|
|
|
209 |
if not deterministic:
|
210 |
-
return self.step(obs).clamped_a
|
211 |
else:
|
212 |
o = self._as_tensor(obs)
|
|
|
|
|
|
|
213 |
with torch.no_grad():
|
214 |
-
(pi, _, _), _ = self._pi_forward(o)
|
215 |
a = pi.mode
|
216 |
return clamp_actions(a.cpu().numpy(), self.action_space, self.squash_output)
|
217 |
|
@@ -220,7 +240,10 @@ class ActorCritic(OnPolicy):
|
|
220 |
self.reset_noise()
|
221 |
|
222 |
def reset_noise(self, batch_size: Optional[int] = None) -> None:
|
223 |
-
|
224 |
-
self.
|
225 |
-
|
226 |
-
|
|
|
|
|
|
|
|
1 |
+
from abc import abstractmethod
|
2 |
+
from typing import NamedTuple, Optional, Sequence, Tuple, TypeVar
|
3 |
+
|
4 |
import gym
|
5 |
import numpy as np
|
6 |
import torch
|
|
|
|
|
7 |
from gym.spaces import Box, Discrete, Space
|
|
|
8 |
|
9 |
+
from rl_algo_impls.shared.actor import PiForward, actor_head
|
10 |
+
from rl_algo_impls.shared.encoder import Encoder
|
|
|
|
|
|
|
|
|
11 |
from rl_algo_impls.shared.policy.critic import CriticHead
|
12 |
from rl_algo_impls.shared.policy.policy import ACTIVATION, Policy
|
13 |
from rl_algo_impls.wrappers.vectorable_wrapper import (
|
14 |
VecEnv,
|
15 |
VecEnvObs,
|
|
|
16 |
single_action_space,
|
17 |
+
single_observation_space,
|
18 |
)
|
19 |
|
20 |
|
|
|
73 |
...
|
74 |
|
75 |
@abstractmethod
|
76 |
+
def step(self, obs: VecEnvObs, action_masks: Optional[np.ndarray] = None) -> Step:
|
77 |
+
...
|
78 |
+
|
79 |
+
@property
|
80 |
+
@abstractmethod
|
81 |
+
def action_shape(self) -> Tuple[int, ...]:
|
82 |
...
|
83 |
|
84 |
|
|
|
95 |
full_std: bool = True,
|
96 |
squash_output: bool = False,
|
97 |
share_features_extractor: bool = True,
|
98 |
+
cnn_flatten_dim: int = 512,
|
99 |
cnn_style: str = "nature",
|
100 |
cnn_layers_init_orthogonal: Optional[bool] = None,
|
101 |
impala_channels: Sequence[int] = (16, 32, 32),
|
102 |
+
actor_head_style: str = "single",
|
103 |
**kwargs,
|
104 |
) -> None:
|
105 |
super().__init__(env, **kwargs)
|
|
|
122 |
self.action_space = action_space
|
123 |
self.squash_output = squash_output
|
124 |
self.share_features_extractor = share_features_extractor
|
125 |
+
self._feature_extractor = Encoder(
|
126 |
observation_space,
|
127 |
activation,
|
128 |
init_layers_orthogonal=init_layers_orthogonal,
|
129 |
+
cnn_flatten_dim=cnn_flatten_dim,
|
130 |
cnn_style=cnn_style,
|
131 |
cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
|
132 |
impala_channels=impala_channels,
|
133 |
)
|
134 |
self._pi = actor_head(
|
135 |
self.action_space,
|
136 |
+
self._feature_extractor.out_dim,
|
137 |
+
tuple(pi_hidden_sizes),
|
138 |
init_layers_orthogonal,
|
139 |
activation,
|
140 |
log_std_init=log_std_init,
|
141 |
use_sde=use_sde,
|
142 |
full_std=full_std,
|
143 |
squash_output=squash_output,
|
144 |
+
actor_head_style=actor_head_style,
|
145 |
)
|
146 |
|
147 |
if not share_features_extractor:
|
148 |
+
self._v_feature_extractor = Encoder(
|
149 |
observation_space,
|
150 |
activation,
|
151 |
init_layers_orthogonal=init_layers_orthogonal,
|
152 |
+
cnn_flatten_dim=cnn_flatten_dim,
|
153 |
cnn_style=cnn_style,
|
154 |
cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
|
155 |
)
|
156 |
+
critic_in_dim = self._v_feature_extractor.out_dim
|
|
|
|
|
157 |
else:
|
158 |
self._v_feature_extractor = None
|
159 |
+
critic_in_dim = self._feature_extractor.out_dim
|
160 |
self._v = CriticHead(
|
161 |
+
in_dim=critic_in_dim,
|
162 |
hidden_sizes=v_hidden_sizes,
|
163 |
activation=activation,
|
164 |
init_layers_orthogonal=init_layers_orthogonal,
|
165 |
)
|
166 |
|
167 |
def _pi_forward(
|
168 |
+
self,
|
169 |
+
obs: torch.Tensor,
|
170 |
+
action_masks: Optional[torch.Tensor],
|
171 |
+
action: Optional[torch.Tensor] = None,
|
172 |
) -> Tuple[PiForward, torch.Tensor]:
|
173 |
p_fe = self._feature_extractor(obs)
|
174 |
+
pi_forward = self._pi(p_fe, actions=action, action_masks=action_masks)
|
175 |
|
176 |
return pi_forward, p_fe
|
177 |
|
|
|
179 |
v_fe = self._v_feature_extractor(obs) if self._v_feature_extractor else p_fc
|
180 |
return self._v(v_fe)
|
181 |
|
182 |
+
def forward(
|
183 |
+
self,
|
184 |
+
obs: torch.Tensor,
|
185 |
+
action: torch.Tensor,
|
186 |
+
action_masks: Optional[torch.Tensor] = None,
|
187 |
+
) -> ACForward:
|
188 |
+
(_, logp_a, entropy), p_fc = self._pi_forward(obs, action_masks, action=action)
|
189 |
v = self._v_forward(obs, p_fc)
|
190 |
|
191 |
assert logp_a is not None
|
|
|
203 |
v = self._v(fe)
|
204 |
return v.cpu().numpy()
|
205 |
|
206 |
+
def step(self, obs: VecEnvObs, action_masks: Optional[np.ndarray] = None) -> Step:
|
207 |
o = self._as_tensor(obs)
|
208 |
+
a_masks = self._as_tensor(action_masks) if action_masks is not None else None
|
209 |
with torch.no_grad():
|
210 |
+
(pi, _, _), p_fc = self._pi_forward(o, action_masks=a_masks)
|
211 |
a = pi.sample()
|
212 |
logp_a = pi.log_prob(a)
|
213 |
|
|
|
217 |
clamped_a_np = clamp_actions(a_np, self.action_space, self.squash_output)
|
218 |
return Step(a_np, v.cpu().numpy(), logp_a.cpu().numpy(), clamped_a_np)
|
219 |
|
220 |
+
def act(
|
221 |
+
self,
|
222 |
+
obs: np.ndarray,
|
223 |
+
deterministic: bool = True,
|
224 |
+
action_masks: Optional[np.ndarray] = None,
|
225 |
+
) -> np.ndarray:
|
226 |
if not deterministic:
|
227 |
+
return self.step(obs, action_masks=action_masks).clamped_a
|
228 |
else:
|
229 |
o = self._as_tensor(obs)
|
230 |
+
a_masks = (
|
231 |
+
self._as_tensor(action_masks) if action_masks is not None else None
|
232 |
+
)
|
233 |
with torch.no_grad():
|
234 |
+
(pi, _, _), _ = self._pi_forward(o, action_masks=a_masks)
|
235 |
a = pi.mode
|
236 |
return clamp_actions(a.cpu().numpy(), self.action_space, self.squash_output)
|
237 |
|
|
|
240 |
self.reset_noise()
|
241 |
|
242 |
def reset_noise(self, batch_size: Optional[int] = None) -> None:
|
243 |
+
self._pi.sample_weights(
|
244 |
+
batch_size=batch_size if batch_size else self.env.num_envs
|
245 |
+
)
|
246 |
+
|
247 |
+
@property
|
248 |
+
def action_shape(self) -> Tuple[int, ...]:
|
249 |
+
return self._pi.action_shape
|
rl_algo_impls/shared/policy/policy.py
CHANGED
@@ -46,7 +46,12 @@ class Policy(nn.Module, ABC):
|
|
46 |
return self
|
47 |
|
48 |
@abstractmethod
|
49 |
-
def act(
|
|
|
|
|
|
|
|
|
|
|
50 |
...
|
51 |
|
52 |
def save(self, path: str) -> None:
|
|
|
46 |
return self
|
47 |
|
48 |
@abstractmethod
|
49 |
+
def act(
|
50 |
+
self,
|
51 |
+
obs: VecEnvObs,
|
52 |
+
deterministic: bool = True,
|
53 |
+
action_masks: Optional[np.ndarray] = None,
|
54 |
+
) -> np.ndarray:
|
55 |
...
|
56 |
|
57 |
def save(self, path: str) -> None:
|
rl_algo_impls/shared/schedule.py
CHANGED
@@ -20,10 +20,38 @@ def constant_schedule(val: float) -> Schedule:
|
|
20 |
return lambda f: val
|
21 |
|
22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
def schedule(name: str, start_val: float) -> Schedule:
|
24 |
if name == "linear":
|
25 |
return linear_schedule(start_val, 0)
|
26 |
-
|
|
|
|
|
|
|
|
|
|
|
27 |
|
28 |
|
29 |
def update_learning_rate(optimizer: Optimizer, learning_rate: float) -> None:
|
|
|
20 |
return lambda f: val
|
21 |
|
22 |
|
23 |
+
def spike_schedule(
|
24 |
+
max_value: float,
|
25 |
+
start_fraction: float = 1e-2,
|
26 |
+
end_fraction: float = 1e-4,
|
27 |
+
peak_progress: float = 0.1,
|
28 |
+
) -> Schedule:
|
29 |
+
assert 0 < peak_progress < 1
|
30 |
+
|
31 |
+
def func(progress_fraction: float) -> float:
|
32 |
+
if progress_fraction < peak_progress:
|
33 |
+
fraction = (
|
34 |
+
start_fraction
|
35 |
+
+ (1 - start_fraction) * progress_fraction / peak_progress
|
36 |
+
)
|
37 |
+
else:
|
38 |
+
fraction = 1 + (end_fraction - 1) * (progress_fraction - peak_progress) / (
|
39 |
+
1 - peak_progress
|
40 |
+
)
|
41 |
+
return max_value * fraction
|
42 |
+
|
43 |
+
return func
|
44 |
+
|
45 |
+
|
46 |
def schedule(name: str, start_val: float) -> Schedule:
|
47 |
if name == "linear":
|
48 |
return linear_schedule(start_val, 0)
|
49 |
+
elif name == "none":
|
50 |
+
return constant_schedule(start_val)
|
51 |
+
elif name == "spike":
|
52 |
+
return spike_schedule(start_val)
|
53 |
+
else:
|
54 |
+
raise ValueError(f"Schedule {name} not supported")
|
55 |
|
56 |
|
57 |
def update_learning_rate(optimizer: Optimizer, learning_rate: float) -> None:
|
rl_algo_impls/shared/stats.py
CHANGED
@@ -1,14 +1,17 @@
|
|
1 |
-
import
|
2 |
-
|
3 |
from dataclasses import dataclass
|
|
|
|
|
|
|
4 |
from torch.utils.tensorboard.writer import SummaryWriter
|
5 |
-
from typing import Dict, List, Optional, Sequence, Union, TypeVar
|
6 |
|
7 |
|
8 |
@dataclass
|
9 |
class Episode:
|
10 |
score: float = 0
|
11 |
length: int = 0
|
|
|
12 |
|
13 |
|
14 |
StatisticSelf = TypeVar("StatisticSelf", bound="Statistic")
|
@@ -75,12 +78,25 @@ class EpisodesStats:
|
|
75 |
simple: bool
|
76 |
score: Statistic
|
77 |
length: Statistic
|
|
|
78 |
|
79 |
def __init__(self, episodes: Sequence[Episode], simple: bool = False) -> None:
|
80 |
self.episodes = episodes
|
81 |
self.simple = simple
|
82 |
self.score = Statistic(np.array([e.score for e in episodes]))
|
83 |
self.length = Statistic(np.array([e.length for e in episodes]), round_digits=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
|
85 |
def __gt__(self: EpisodesStatsSelf, o: EpisodesStatsSelf) -> bool:
|
86 |
return self.score > o.score
|
@@ -118,6 +134,8 @@ class EpisodesStats:
|
|
118 |
"length": self.length.mean,
|
119 |
}
|
120 |
)
|
|
|
|
|
121 |
for name, value in stats.items():
|
122 |
tb_writer.add_scalar(f"{main_tag}/{name}", value, global_step=global_step)
|
123 |
|
@@ -131,19 +149,19 @@ class EpisodeAccumulator:
|
|
131 |
def episodes(self) -> List[Episode]:
|
132 |
return self._episodes
|
133 |
|
134 |
-
def step(self, reward: np.ndarray, done: np.ndarray) -> None:
|
135 |
for idx, current in enumerate(self.current_episodes):
|
136 |
current.score += reward[idx]
|
137 |
current.length += 1
|
138 |
if done[idx]:
|
139 |
self._episodes.append(current)
|
140 |
self.current_episodes[idx] = Episode()
|
141 |
-
self.on_done(idx, current)
|
142 |
|
143 |
def __len__(self) -> int:
|
144 |
return len(self.episodes)
|
145 |
|
146 |
-
def on_done(self, ep_idx: int, episode: Episode) -> None:
|
147 |
pass
|
148 |
|
149 |
def stats(self) -> EpisodesStats:
|
|
|
1 |
+
import dataclasses
|
2 |
+
from collections import defaultdict
|
3 |
from dataclasses import dataclass
|
4 |
+
from typing import Any, Dict, List, Optional, Sequence, TypeVar, Union
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
from torch.utils.tensorboard.writer import SummaryWriter
|
|
|
8 |
|
9 |
|
10 |
@dataclass
|
11 |
class Episode:
|
12 |
score: float = 0
|
13 |
length: int = 0
|
14 |
+
info: Dict[str, Dict[str, Any]] = dataclasses.field(default_factory=dict)
|
15 |
|
16 |
|
17 |
StatisticSelf = TypeVar("StatisticSelf", bound="Statistic")
|
|
|
78 |
simple: bool
|
79 |
score: Statistic
|
80 |
length: Statistic
|
81 |
+
additional_stats: Dict[str, Statistic]
|
82 |
|
83 |
def __init__(self, episodes: Sequence[Episode], simple: bool = False) -> None:
|
84 |
self.episodes = episodes
|
85 |
self.simple = simple
|
86 |
self.score = Statistic(np.array([e.score for e in episodes]))
|
87 |
self.length = Statistic(np.array([e.length for e in episodes]), round_digits=0)
|
88 |
+
additional_values = defaultdict(list)
|
89 |
+
for e in self.episodes:
|
90 |
+
if e.info:
|
91 |
+
for k, v in e.info.items():
|
92 |
+
if isinstance(v, dict):
|
93 |
+
for k2, v2 in v.items():
|
94 |
+
additional_values[f"{k}_{k2}"].append(v2)
|
95 |
+
else:
|
96 |
+
additional_values[k].append(v)
|
97 |
+
self.additional_stats = {
|
98 |
+
k: Statistic(np.array(values)) for k, values in additional_values.items()
|
99 |
+
}
|
100 |
|
101 |
def __gt__(self: EpisodesStatsSelf, o: EpisodesStatsSelf) -> bool:
|
102 |
return self.score > o.score
|
|
|
134 |
"length": self.length.mean,
|
135 |
}
|
136 |
)
|
137 |
+
for k, addl_stats in self.additional_stats.items():
|
138 |
+
stats[k] = addl_stats.mean
|
139 |
for name, value in stats.items():
|
140 |
tb_writer.add_scalar(f"{main_tag}/{name}", value, global_step=global_step)
|
141 |
|
|
|
149 |
def episodes(self) -> List[Episode]:
|
150 |
return self._episodes
|
151 |
|
152 |
+
def step(self, reward: np.ndarray, done: np.ndarray, info: List[Dict]) -> None:
|
153 |
for idx, current in enumerate(self.current_episodes):
|
154 |
current.score += reward[idx]
|
155 |
current.length += 1
|
156 |
if done[idx]:
|
157 |
self._episodes.append(current)
|
158 |
self.current_episodes[idx] = Episode()
|
159 |
+
self.on_done(idx, current, info[idx])
|
160 |
|
161 |
def __len__(self) -> int:
|
162 |
return len(self.episodes)
|
163 |
|
164 |
+
def on_done(self, ep_idx: int, episode: Episode, info: Dict) -> None:
|
165 |
pass
|
166 |
|
167 |
def stats(self) -> EpisodesStats:
|
rl_algo_impls/shared/vec_env/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from rl_algo_impls.shared.vec_env.make_env import make_env, make_eval_env
|
rl_algo_impls/shared/vec_env/make_env.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import asdict
|
2 |
+
from typing import Optional
|
3 |
+
|
4 |
+
from torch.utils.tensorboard.writer import SummaryWriter
|
5 |
+
|
6 |
+
from rl_algo_impls.runner.config import Config, EnvHyperparams
|
7 |
+
from rl_algo_impls.shared.vec_env.microrts import make_microrts_env
|
8 |
+
from rl_algo_impls.shared.vec_env.procgen import make_procgen_env
|
9 |
+
from rl_algo_impls.shared.vec_env.vec_env import make_vec_env
|
10 |
+
from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv
|
11 |
+
|
12 |
+
|
13 |
+
def make_env(
|
14 |
+
config: Config,
|
15 |
+
hparams: EnvHyperparams,
|
16 |
+
training: bool = True,
|
17 |
+
render: bool = False,
|
18 |
+
normalize_load_path: Optional[str] = None,
|
19 |
+
tb_writer: Optional[SummaryWriter] = None,
|
20 |
+
) -> VecEnv:
|
21 |
+
if hparams.env_type == "procgen":
|
22 |
+
return make_procgen_env(
|
23 |
+
config,
|
24 |
+
hparams,
|
25 |
+
training=training,
|
26 |
+
render=render,
|
27 |
+
normalize_load_path=normalize_load_path,
|
28 |
+
tb_writer=tb_writer,
|
29 |
+
)
|
30 |
+
elif hparams.env_type in {"sb3vec", "gymvec"}:
|
31 |
+
return make_vec_env(
|
32 |
+
config,
|
33 |
+
hparams,
|
34 |
+
training=training,
|
35 |
+
render=render,
|
36 |
+
normalize_load_path=normalize_load_path,
|
37 |
+
tb_writer=tb_writer,
|
38 |
+
)
|
39 |
+
elif hparams.env_type == "microrts":
|
40 |
+
return make_microrts_env(
|
41 |
+
config,
|
42 |
+
hparams,
|
43 |
+
training=training,
|
44 |
+
render=render,
|
45 |
+
normalize_load_path=normalize_load_path,
|
46 |
+
tb_writer=tb_writer,
|
47 |
+
)
|
48 |
+
else:
|
49 |
+
raise ValueError(f"env_type {hparams.env_type} not supported")
|
50 |
+
|
51 |
+
|
52 |
+
def make_eval_env(
|
53 |
+
config: Config,
|
54 |
+
hparams: EnvHyperparams,
|
55 |
+
override_n_envs: Optional[int] = None,
|
56 |
+
**kwargs,
|
57 |
+
) -> VecEnv:
|
58 |
+
kwargs = kwargs.copy()
|
59 |
+
kwargs["training"] = False
|
60 |
+
if override_n_envs is not None:
|
61 |
+
hparams_kwargs = asdict(hparams)
|
62 |
+
hparams_kwargs["n_envs"] = override_n_envs
|
63 |
+
if override_n_envs == 1:
|
64 |
+
hparams_kwargs["vec_env_class"] = "sync"
|
65 |
+
hparams = EnvHyperparams(**hparams_kwargs)
|
66 |
+
return make_env(config, hparams, **kwargs)
|
rl_algo_impls/shared/vec_env/microrts.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import astuple
|
2 |
+
from typing import Optional
|
3 |
+
|
4 |
+
import gym
|
5 |
+
import numpy as np
|
6 |
+
from torch.utils.tensorboard.writer import SummaryWriter
|
7 |
+
|
8 |
+
from rl_algo_impls.runner.config import Config, EnvHyperparams
|
9 |
+
from rl_algo_impls.wrappers.action_mask_wrapper import MicrortsMaskWrapper
|
10 |
+
from rl_algo_impls.wrappers.episode_stats_writer import EpisodeStatsWriter
|
11 |
+
from rl_algo_impls.wrappers.hwc_to_chw_observation import HwcToChwObservation
|
12 |
+
from rl_algo_impls.wrappers.is_vector_env import IsVectorEnv
|
13 |
+
from rl_algo_impls.wrappers.microrts_stats_recorder import MicrortsStatsRecorder
|
14 |
+
from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv
|
15 |
+
|
16 |
+
|
17 |
+
def make_microrts_env(
|
18 |
+
config: Config,
|
19 |
+
hparams: EnvHyperparams,
|
20 |
+
training: bool = True,
|
21 |
+
render: bool = False,
|
22 |
+
normalize_load_path: Optional[str] = None,
|
23 |
+
tb_writer: Optional[SummaryWriter] = None,
|
24 |
+
) -> VecEnv:
|
25 |
+
import gym_microrts
|
26 |
+
from gym_microrts import microrts_ai
|
27 |
+
|
28 |
+
from rl_algo_impls.shared.vec_env.microrts_compat import (
|
29 |
+
MicroRTSGridModeVecEnvCompat,
|
30 |
+
)
|
31 |
+
|
32 |
+
(
|
33 |
+
_, # env_type
|
34 |
+
n_envs,
|
35 |
+
_, # frame_stack
|
36 |
+
make_kwargs,
|
37 |
+
_, # no_reward_timeout_steps
|
38 |
+
_, # no_reward_fire_steps
|
39 |
+
_, # vec_env_class
|
40 |
+
_, # normalize
|
41 |
+
_, # normalize_kwargs,
|
42 |
+
rolling_length,
|
43 |
+
_, # train_record_video
|
44 |
+
_, # video_step_interval
|
45 |
+
_, # initial_steps_to_truncate
|
46 |
+
_, # clip_atari_rewards
|
47 |
+
_, # normalize_type
|
48 |
+
_, # mask_actions
|
49 |
+
bots,
|
50 |
+
) = astuple(hparams)
|
51 |
+
|
52 |
+
seed = config.seed(training=training)
|
53 |
+
|
54 |
+
make_kwargs = make_kwargs or {}
|
55 |
+
if "num_selfplay_envs" not in make_kwargs:
|
56 |
+
make_kwargs["num_selfplay_envs"] = 0
|
57 |
+
if "num_bot_envs" not in make_kwargs:
|
58 |
+
make_kwargs["num_bot_envs"] = n_envs - make_kwargs["num_selfplay_envs"]
|
59 |
+
if "reward_weight" in make_kwargs:
|
60 |
+
make_kwargs["reward_weight"] = np.array(make_kwargs["reward_weight"])
|
61 |
+
if bots:
|
62 |
+
ai2s = []
|
63 |
+
for ai_name, n in bots.items():
|
64 |
+
for _ in range(n):
|
65 |
+
if len(ai2s) >= make_kwargs["num_bot_envs"]:
|
66 |
+
break
|
67 |
+
ai = getattr(microrts_ai, ai_name)
|
68 |
+
assert ai, f"{ai_name} not in microrts_ai"
|
69 |
+
ai2s.append(ai)
|
70 |
+
else:
|
71 |
+
ai2s = [microrts_ai.randomAI for _ in make_kwargs["num_bot_envs"]]
|
72 |
+
make_kwargs["ai2s"] = ai2s
|
73 |
+
envs = MicroRTSGridModeVecEnvCompat(**make_kwargs)
|
74 |
+
envs = HwcToChwObservation(envs)
|
75 |
+
envs = IsVectorEnv(envs)
|
76 |
+
envs = MicrortsMaskWrapper(envs)
|
77 |
+
|
78 |
+
if seed is not None:
|
79 |
+
envs.action_space.seed(seed)
|
80 |
+
envs.observation_space.seed(seed)
|
81 |
+
|
82 |
+
envs = gym.wrappers.RecordEpisodeStatistics(envs)
|
83 |
+
envs = MicrortsStatsRecorder(envs, config.algo_hyperparams.get("gamma", 0.99))
|
84 |
+
if training:
|
85 |
+
assert tb_writer
|
86 |
+
envs = EpisodeStatsWriter(
|
87 |
+
envs,
|
88 |
+
tb_writer,
|
89 |
+
training=training,
|
90 |
+
rolling_length=rolling_length,
|
91 |
+
additional_keys_to_log=config.additional_keys_to_log,
|
92 |
+
)
|
93 |
+
|
94 |
+
return envs
|
rl_algo_impls/shared/vec_env/microrts_compat.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import TypeVar
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
from gym_microrts.envs.vec_env import MicroRTSGridModeVecEnv
|
5 |
+
from jpype.types import JArray, JInt
|
6 |
+
|
7 |
+
from rl_algo_impls.wrappers.vectorable_wrapper import VecEnvStepReturn
|
8 |
+
|
9 |
+
MicroRTSGridModeVecEnvCompatSelf = TypeVar(
|
10 |
+
"MicroRTSGridModeVecEnvCompatSelf", bound="MicroRTSGridModeVecEnvCompat"
|
11 |
+
)
|
12 |
+
|
13 |
+
|
14 |
+
class MicroRTSGridModeVecEnvCompat(MicroRTSGridModeVecEnv):
|
15 |
+
def step(self, action: np.ndarray) -> VecEnvStepReturn:
|
16 |
+
indexed_actions = np.concatenate(
|
17 |
+
[
|
18 |
+
np.expand_dims(
|
19 |
+
np.stack(
|
20 |
+
[np.arange(0, action.shape[1]) for i in range(self.num_envs)]
|
21 |
+
),
|
22 |
+
axis=2,
|
23 |
+
),
|
24 |
+
action,
|
25 |
+
],
|
26 |
+
axis=2,
|
27 |
+
)
|
28 |
+
action_mask = np.array(self.vec_client.getMasks(0), dtype=np.bool8).reshape(
|
29 |
+
indexed_actions.shape[:-1] + (-1,)
|
30 |
+
)
|
31 |
+
valid_action_mask = action_mask[:, :, 0]
|
32 |
+
valid_actions_counts = valid_action_mask.sum(1)
|
33 |
+
valid_actions = indexed_actions[valid_action_mask]
|
34 |
+
valid_actions_idx = 0
|
35 |
+
|
36 |
+
all_valid_actions = []
|
37 |
+
for env_act_cnt in valid_actions_counts:
|
38 |
+
env_valid_actions = []
|
39 |
+
for _ in range(env_act_cnt):
|
40 |
+
env_valid_actions.append(JArray(JInt)(valid_actions[valid_actions_idx]))
|
41 |
+
valid_actions_idx += 1
|
42 |
+
all_valid_actions.append(JArray(JArray(JInt))(env_valid_actions))
|
43 |
+
return super().step(JArray(JArray(JArray(JInt)))(all_valid_actions)) # type: ignore
|
44 |
+
|
45 |
+
@property
|
46 |
+
def unwrapped(
|
47 |
+
self: MicroRTSGridModeVecEnvCompatSelf,
|
48 |
+
) -> MicroRTSGridModeVecEnvCompatSelf:
|
49 |
+
return self
|
rl_algo_impls/shared/vec_env/procgen.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import astuple
|
2 |
+
from typing import Optional
|
3 |
+
|
4 |
+
import gym
|
5 |
+
import numpy as np
|
6 |
+
from torch.utils.tensorboard.writer import SummaryWriter
|
7 |
+
|
8 |
+
from rl_algo_impls.runner.config import Config, EnvHyperparams
|
9 |
+
from rl_algo_impls.wrappers.episode_stats_writer import EpisodeStatsWriter
|
10 |
+
from rl_algo_impls.wrappers.hwc_to_chw_observation import HwcToChwObservation
|
11 |
+
from rl_algo_impls.wrappers.is_vector_env import IsVectorEnv
|
12 |
+
from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv
|
13 |
+
|
14 |
+
|
15 |
+
def make_procgen_env(
|
16 |
+
config: Config,
|
17 |
+
hparams: EnvHyperparams,
|
18 |
+
training: bool = True,
|
19 |
+
render: bool = False,
|
20 |
+
normalize_load_path: Optional[str] = None,
|
21 |
+
tb_writer: Optional[SummaryWriter] = None,
|
22 |
+
) -> VecEnv:
|
23 |
+
from gym3 import ExtractDictObWrapper, ViewerWrapper
|
24 |
+
from procgen.env import ProcgenGym3Env, ToBaselinesVecEnv
|
25 |
+
|
26 |
+
(
|
27 |
+
_, # env_type
|
28 |
+
n_envs,
|
29 |
+
_, # frame_stack
|
30 |
+
make_kwargs,
|
31 |
+
_, # no_reward_timeout_steps
|
32 |
+
_, # no_reward_fire_steps
|
33 |
+
_, # vec_env_class
|
34 |
+
normalize,
|
35 |
+
normalize_kwargs,
|
36 |
+
rolling_length,
|
37 |
+
_, # train_record_video
|
38 |
+
_, # video_step_interval
|
39 |
+
_, # initial_steps_to_truncate
|
40 |
+
_, # clip_atari_rewards
|
41 |
+
_, # normalize_type
|
42 |
+
_, # mask_actions
|
43 |
+
_, # bots
|
44 |
+
) = astuple(hparams)
|
45 |
+
|
46 |
+
seed = config.seed(training=training)
|
47 |
+
|
48 |
+
make_kwargs = make_kwargs or {}
|
49 |
+
make_kwargs["render_mode"] = "rgb_array"
|
50 |
+
if seed is not None:
|
51 |
+
make_kwargs["rand_seed"] = seed
|
52 |
+
|
53 |
+
envs = ProcgenGym3Env(n_envs, config.env_id, **make_kwargs)
|
54 |
+
envs = ExtractDictObWrapper(envs, key="rgb")
|
55 |
+
if render:
|
56 |
+
envs = ViewerWrapper(envs, info_key="rgb")
|
57 |
+
envs = ToBaselinesVecEnv(envs)
|
58 |
+
envs = IsVectorEnv(envs)
|
59 |
+
# TODO: Handle Grayscale and/or FrameStack
|
60 |
+
envs = HwcToChwObservation(envs)
|
61 |
+
|
62 |
+
envs = gym.wrappers.RecordEpisodeStatistics(envs)
|
63 |
+
|
64 |
+
if seed is not None:
|
65 |
+
envs.action_space.seed(seed)
|
66 |
+
envs.observation_space.seed(seed)
|
67 |
+
|
68 |
+
if training:
|
69 |
+
assert tb_writer
|
70 |
+
envs = EpisodeStatsWriter(
|
71 |
+
envs, tb_writer, training=training, rolling_length=rolling_length
|
72 |
+
)
|
73 |
+
if normalize and training:
|
74 |
+
normalize_kwargs = normalize_kwargs or {}
|
75 |
+
envs = gym.wrappers.NormalizeReward(envs)
|
76 |
+
clip_obs = normalize_kwargs.get("clip_reward", 10.0)
|
77 |
+
envs = gym.wrappers.TransformReward(
|
78 |
+
envs, lambda r: np.clip(r, -clip_obs, clip_obs)
|
79 |
+
)
|
80 |
+
|
81 |
+
return envs # type: ignore
|