vwxyzjn commited on
Commit
cda4ea5
1 Parent(s): d2a2550

pushing model

Browse files
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ cleanba_ppo_envpool_impala_atari_wrapper.cleanrl_model filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ tags:
3
+ - Breakout-v5
4
+ - deep-reinforcement-learning
5
+ - reinforcement-learning
6
+ - custom-implementation
7
+ library_name: cleanrl
8
+ model-index:
9
+ - name: PPO
10
+ results:
11
+ - task:
12
+ type: reinforcement-learning
13
+ name: reinforcement-learning
14
+ dataset:
15
+ name: Breakout-v5
16
+ type: Breakout-v5
17
+ metrics:
18
+ - type: mean_reward
19
+ value: 755.60 +/- 175.27
20
+ name: mean_reward
21
+ verified: false
22
+ ---
23
+
24
+ # (CleanRL) **PPO** Agent Playing **Breakout-v5**
25
+
26
+ This is a trained model of a PPO agent playing Breakout-v5.
27
+ The model was trained by using [CleanRL](https://github.com/vwxyzjn/cleanrl) and the most up-to-date training code can be
28
+ found [here](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/cleanba_ppo_envpool_impala_atari_wrapper.py).
29
+
30
+ ## Get Started
31
+
32
+ To use this model, please install the `cleanrl` package with the following command:
33
+
34
+ ```
35
+ pip install "cleanrl[jax,envpool,atari]"
36
+ python -m cleanrl_utils.enjoy --exp-name cleanba_ppo_envpool_impala_atari_wrapper --env-id Breakout-v5
37
+ ```
38
+
39
+ Please refer to the [documentation](https://docs.cleanrl.dev/get-started/zoo/) for more detail.
40
+
41
+
42
+ ## Command to reproduce the training
43
+
44
+ ```bash
45
+ curl -OL https://huggingface.co/cleanrl/Breakout-v5-cleanba_ppo_envpool_impala_atari_wrapper-seed1/raw/main/cleanba_ppo_envpool_impala_atari_wrapper.py
46
+ curl -OL https://huggingface.co/cleanrl/Breakout-v5-cleanba_ppo_envpool_impala_atari_wrapper-seed1/raw/main/pyproject.toml
47
+ curl -OL https://huggingface.co/cleanrl/Breakout-v5-cleanba_ppo_envpool_impala_atari_wrapper-seed1/raw/main/poetry.lock
48
+ poetry install --all-extras
49
+ python cleanba_ppo_envpool_impala_atari_wrapper.py --distributed --learner-device-ids 1 2 3 --track --wandb-project-name cleanba --save-model --upload-model --hf-entity cleanrl --env-id Breakout-v5 --seed 1
50
+ ```
51
+
52
+ # Hyperparameters
53
+ ```python
54
+ {'actor_device_ids': [0],
55
+ 'actor_devices': ['gpu:0'],
56
+ 'anneal_lr': True,
57
+ 'async_batch_size': 20,
58
+ 'async_update': 3,
59
+ 'batch_size': 15360,
60
+ 'capture_video': False,
61
+ 'clip_coef': 0.1,
62
+ 'concurrency': True,
63
+ 'cuda': True,
64
+ 'distributed': True,
65
+ 'ent_coef': 0.01,
66
+ 'env_id': 'Breakout-v5',
67
+ 'exp_name': 'cleanba_ppo_envpool_impala_atari_wrapper',
68
+ 'gae_lambda': 0.95,
69
+ 'gamma': 0.99,
70
+ 'global_learner_decices': ['gpu:1',
71
+ 'gpu:2',
72
+ 'gpu:3',
73
+ 'gpu:5',
74
+ 'gpu:6',
75
+ 'gpu:7'],
76
+ 'hf_entity': 'cleanrl',
77
+ 'learner_device_ids': [1, 2, 3],
78
+ 'learner_devices': ['gpu:1', 'gpu:2', 'gpu:3'],
79
+ 'learning_rate': 0.00025,
80
+ 'local_batch_size': 7680,
81
+ 'local_minibatch_size': 1920,
82
+ 'local_num_envs': 60,
83
+ 'local_rank': 0,
84
+ 'max_grad_norm': 0.5,
85
+ 'minibatch_size': 3840,
86
+ 'norm_adv': True,
87
+ 'num_envs': 120,
88
+ 'num_minibatches': 4,
89
+ 'num_steps': 128,
90
+ 'num_updates': 3255,
91
+ 'profile': False,
92
+ 'save_model': True,
93
+ 'seed': 1,
94
+ 'target_kl': None,
95
+ 'test_actor_learner_throughput': False,
96
+ 'torch_deterministic': True,
97
+ 'total_timesteps': 50000000,
98
+ 'track': True,
99
+ 'update_epochs': 4,
100
+ 'upload_model': True,
101
+ 'vf_coef': 0.5,
102
+ 'wandb_entity': None,
103
+ 'wandb_project_name': 'cleanba',
104
+ 'world_size': 2}
105
+ ```
106
+
cleanba_ppo_envpool_impala_atari_wrapper.cleanrl_model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b86724db89a4dc1b48a6e40e06a6a4fa5caed3c5f837640274925d44e5de2311
3
+ size 4364166
cleanba_ppo_envpool_impala_atari_wrapper.py ADDED
@@ -0,0 +1,866 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import random
4
+ import time
5
+ import uuid
6
+ from collections import deque
7
+ from distutils.util import strtobool
8
+ from functools import partial
9
+ from typing import Sequence
10
+
11
+ os.environ[
12
+ "XLA_PYTHON_CLIENT_MEM_FRACTION"
13
+ ] = "0.6" # see https://github.com/google/jax/discussions/6332#discussioncomment-1279991
14
+ os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false " "intra_op_parallelism_threads=1"
15
+ import queue
16
+ import threading
17
+
18
+ import envpool
19
+ import flax
20
+ import flax.linen as nn
21
+ import gym
22
+ import jax
23
+ import jax.numpy as jnp
24
+ import numpy as np
25
+ import optax
26
+ from flax.linen.initializers import constant, orthogonal
27
+ from flax.training.train_state import TrainState
28
+ from tensorboardX import SummaryWriter
29
+
30
+
31
+ def parse_args():
32
+ # fmt: off
33
+ parser = argparse.ArgumentParser()
34
+ parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"),
35
+ help="the name of this experiment")
36
+ parser.add_argument("--seed", type=int, default=1,
37
+ help="seed of the experiment")
38
+ parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
39
+ help="if toggled, `torch.backends.cudnn.deterministic=False`")
40
+ parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
41
+ help="if toggled, cuda will be enabled by default")
42
+ parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
43
+ help="if toggled, this experiment will be tracked with Weights and Biases")
44
+ parser.add_argument("--wandb-project-name", type=str, default="cleanRL",
45
+ help="the wandb's project name")
46
+ parser.add_argument("--wandb-entity", type=str, default=None,
47
+ help="the entity (team) of wandb's project")
48
+ parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
49
+ help="whether to capture videos of the agent performances (check out `videos` folder)")
50
+ parser.add_argument("--save-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
51
+ help="whether to save model into the `runs/{run_name}` folder")
52
+ parser.add_argument("--upload-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
53
+ help="whether to upload the saved model to huggingface")
54
+ parser.add_argument("--hf-entity", type=str, default="",
55
+ help="the user or org name of the model repository from the Hugging Face Hub")
56
+
57
+ # Algorithm specific arguments
58
+ parser.add_argument("--env-id", type=str, default="Breakout-v5",
59
+ help="the id of the environment")
60
+ parser.add_argument("--total-timesteps", type=int, default=50000000,
61
+ help="total timesteps of the experiments")
62
+ parser.add_argument("--learning-rate", type=float, default=2.5e-4,
63
+ help="the learning rate of the optimizer")
64
+ parser.add_argument("--local-num-envs", type=int, default=60,
65
+ help="the number of parallel game environments")
66
+ parser.add_argument("--async-batch-size", type=int, default=20,
67
+ help="the envpool's batch size in the async mode")
68
+ parser.add_argument("--num-steps", type=int, default=128,
69
+ help="the number of steps to run in each environment per policy rollout")
70
+ parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
71
+ help="Toggle learning rate annealing for policy and value networks")
72
+ parser.add_argument("--gamma", type=float, default=0.99,
73
+ help="the discount factor gamma")
74
+ parser.add_argument("--gae-lambda", type=float, default=0.95,
75
+ help="the lambda for the general advantage estimation")
76
+ parser.add_argument("--num-minibatches", type=int, default=4,
77
+ help="the number of mini-batches")
78
+ parser.add_argument("--update-epochs", type=int, default=4,
79
+ help="the K epochs to update the policy")
80
+ parser.add_argument("--norm-adv", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
81
+ help="Toggles advantages normalization")
82
+ parser.add_argument("--clip-coef", type=float, default=0.1,
83
+ help="the surrogate clipping coefficient")
84
+ parser.add_argument("--ent-coef", type=float, default=0.01,
85
+ help="coefficient of the entropy")
86
+ parser.add_argument("--vf-coef", type=float, default=0.5,
87
+ help="coefficient of the value function")
88
+ parser.add_argument("--max-grad-norm", type=float, default=0.5,
89
+ help="the maximum norm for the gradient clipping")
90
+ parser.add_argument("--target-kl", type=float, default=None,
91
+ help="the target KL divergence threshold")
92
+
93
+ parser.add_argument("--actor-device-ids", type=int, nargs="+", default=[0], # type is actually List[int]
94
+ help="the device ids that actor workers will use (currently only support 1 device)")
95
+ parser.add_argument("--learner-device-ids", type=int, nargs="+", default=[0], # type is actually List[int]
96
+ help="the device ids that learner workers will use")
97
+ parser.add_argument("--distributed", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
98
+ help="whether to use `jax.distirbuted`")
99
+ parser.add_argument("--concurrency", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
100
+ help="whether to run the actor and learner concurrently")
101
+ parser.add_argument("--profile", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
102
+ help="whether to call block_until_ready() for profiling")
103
+ parser.add_argument("--test-actor-learner-throughput", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
104
+ help="whether to test actor-learner throughput by removing the actor-learner communication")
105
+ args = parser.parse_args()
106
+ args.local_batch_size = int(args.local_num_envs * args.num_steps)
107
+ args.local_minibatch_size = int(args.local_batch_size // args.num_minibatches)
108
+ args.num_updates = args.total_timesteps // args.local_batch_size
109
+ args.async_update = int(args.local_num_envs / args.async_batch_size)
110
+ assert len(args.actor_device_ids) == 1, "only 1 actor_device_ids is supported now"
111
+ # fmt: on
112
+ return args
113
+
114
+
115
+ ATARI_MAX_FRAMES = int(
116
+ 108000 / 4
117
+ ) # 108000 is the max number of frames in an Atari game, divided by 4 to account for frame skipping
118
+
119
+
120
+ def make_env(env_id, seed, num_envs, async_batch_size=1):
121
+ def thunk():
122
+ envs = envpool.make(
123
+ env_id,
124
+ env_type="gym",
125
+ num_envs=num_envs,
126
+ batch_size=async_batch_size,
127
+ episodic_life=True, # Espeholt et al., 2018, Tab. G.1
128
+ repeat_action_probability=0, # Hessel et al., 2022 (Muesli) Tab. 10
129
+ noop_max=30, # Espeholt et al., 2018, Tab. C.1 "Up to 30 no-ops at the beginning of each episode."
130
+ full_action_space=False, # Espeholt et al., 2018, Appendix G., "Following related work, experts use game-specific action sets."
131
+ max_episode_steps=ATARI_MAX_FRAMES, # Hessel et al. 2018 (Rainbow DQN), Table 3, Max frames per episode
132
+ reward_clip=True,
133
+ seed=seed,
134
+ )
135
+ envs.num_envs = num_envs
136
+ envs.single_action_space = envs.action_space
137
+ envs.single_observation_space = envs.observation_space
138
+ envs.is_vector_env = True
139
+ return envs
140
+
141
+ return thunk
142
+
143
+
144
+ class ResidualBlock(nn.Module):
145
+ channels: int
146
+
147
+ @nn.compact
148
+ def __call__(self, x):
149
+ inputs = x
150
+ x = nn.relu(x)
151
+ x = nn.Conv(
152
+ self.channels,
153
+ kernel_size=(3, 3),
154
+ )(x)
155
+ x = nn.relu(x)
156
+ x = nn.Conv(
157
+ self.channels,
158
+ kernel_size=(3, 3),
159
+ )(x)
160
+ return x + inputs
161
+
162
+
163
+ class ConvSequence(nn.Module):
164
+ channels: int
165
+
166
+ @nn.compact
167
+ def __call__(self, x):
168
+ x = nn.Conv(
169
+ self.channels,
170
+ kernel_size=(3, 3),
171
+ )(x)
172
+ x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding="SAME")
173
+ x = ResidualBlock(self.channels)(x)
174
+ x = ResidualBlock(self.channels)(x)
175
+ return x
176
+
177
+
178
+ class Network(nn.Module):
179
+ channelss: Sequence[int] = (16, 32, 32)
180
+
181
+ @nn.compact
182
+ def __call__(self, x):
183
+ x = jnp.transpose(x, (0, 2, 3, 1))
184
+ x = x / (255.0)
185
+ for channels in self.channelss:
186
+ x = ConvSequence(channels)(x)
187
+ x = nn.relu(x)
188
+ x = x.reshape((x.shape[0], -1))
189
+ x = nn.Dense(256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(x)
190
+ x = nn.relu(x)
191
+ return x
192
+
193
+
194
+ class Critic(nn.Module):
195
+ @nn.compact
196
+ def __call__(self, x):
197
+ return nn.Dense(1, kernel_init=orthogonal(1), bias_init=constant(0.0))(x)
198
+
199
+
200
+ class Actor(nn.Module):
201
+ action_dim: int
202
+
203
+ @nn.compact
204
+ def __call__(self, x):
205
+ return nn.Dense(self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0))(x)
206
+
207
+
208
+ @flax.struct.dataclass
209
+ class AgentParams:
210
+ network_params: flax.core.FrozenDict
211
+ actor_params: flax.core.FrozenDict
212
+ critic_params: flax.core.FrozenDict
213
+
214
+
215
+ @partial(jax.jit, static_argnums=(3))
216
+ def get_action_and_value(
217
+ params: flax.core.FrozenDict,
218
+ next_obs: np.ndarray,
219
+ key: jax.random.PRNGKey,
220
+ action_dim: int,
221
+ ):
222
+ next_obs = jnp.array(next_obs)
223
+ hidden = Network().apply(params.network_params, next_obs)
224
+ logits = Actor(action_dim).apply(params.actor_params, hidden)
225
+ # sample action: Gumbel-softmax trick
226
+ # see https://stats.stackexchange.com/questions/359442/sampling-from-a-categorical-distribution
227
+ key, subkey = jax.random.split(key)
228
+ u = jax.random.uniform(subkey, shape=logits.shape)
229
+ action = jnp.argmax(logits - jnp.log(-jnp.log(u)), axis=1)
230
+ logprob = jax.nn.log_softmax(logits)[jnp.arange(action.shape[0]), action]
231
+ value = Critic().apply(params.critic_params, hidden)
232
+ return next_obs, action, logprob, value.squeeze(), key
233
+
234
+
235
+ def prepare_data(
236
+ obs: list,
237
+ dones: list,
238
+ values: list,
239
+ actions: list,
240
+ logprobs: list,
241
+ env_ids: list,
242
+ rewards: list,
243
+ ):
244
+ obs = jnp.asarray(obs)
245
+ dones = jnp.asarray(dones)
246
+ values = jnp.asarray(values)
247
+ actions = jnp.asarray(actions)
248
+ logprobs = jnp.asarray(logprobs)
249
+ env_ids = jnp.asarray(env_ids)
250
+ rewards = jnp.asarray(rewards)
251
+
252
+ # TODO: in an unlikely event, one of the envs might have not stepped at all, which may results in unexpected behavior
253
+ T, B = env_ids.shape
254
+ index_ranges = jnp.arange(T * B, dtype=jnp.int32)
255
+ next_index_ranges = jnp.zeros_like(index_ranges, dtype=jnp.int32)
256
+ last_env_ids = jnp.zeros(args.local_num_envs, dtype=jnp.int32) - 1
257
+
258
+ def f(carry, x):
259
+ last_env_ids, next_index_ranges = carry
260
+ env_id, index_range = x
261
+ next_index_ranges = next_index_ranges.at[last_env_ids[env_id]].set(
262
+ jnp.where(last_env_ids[env_id] != -1, index_range, next_index_ranges[last_env_ids[env_id]])
263
+ )
264
+ last_env_ids = last_env_ids.at[env_id].set(index_range)
265
+ return (last_env_ids, next_index_ranges), None
266
+
267
+ (last_env_ids, next_index_ranges), _ = jax.lax.scan(
268
+ f,
269
+ (last_env_ids, next_index_ranges),
270
+ (env_ids.reshape(-1), index_ranges),
271
+ )
272
+
273
+ # rewards is off by one time step
274
+ rewards = rewards.reshape(-1)[next_index_ranges].reshape((args.num_steps) * args.async_update, args.async_batch_size)
275
+ advantages, returns, _, final_env_ids = compute_gae(env_ids, rewards, values, dones)
276
+ # b_inds = jnp.nonzero(final_env_ids.reshape(-1), size=(args.num_steps) * args.async_update * args.async_batch_size)[0] # useful for debugging
277
+ b_obs = obs.reshape((-1,) + obs.shape[2:])
278
+ b_actions = actions.reshape(-1)
279
+ b_logprobs = logprobs.reshape(-1)
280
+ b_advantages = advantages.reshape(-1)
281
+ b_returns = returns.reshape(-1)
282
+ return b_obs, b_actions, b_logprobs, b_advantages, b_returns
283
+
284
+
285
+ @jax.jit
286
+ def make_bulk_array(
287
+ obs: list,
288
+ values: list,
289
+ actions: list,
290
+ logprobs: list,
291
+ ):
292
+ obs = jnp.asarray(obs)
293
+ values = jnp.asarray(values)
294
+ actions = jnp.asarray(actions)
295
+ logprobs = jnp.asarray(logprobs)
296
+ return obs, values, actions, logprobs
297
+
298
+
299
+ def rollout(
300
+ key: jax.random.PRNGKey,
301
+ args,
302
+ rollout_queue,
303
+ params_queue: queue.Queue,
304
+ writer,
305
+ learner_devices,
306
+ ):
307
+ envs = make_env(args.env_id, args.seed + jax.process_index(), args.local_num_envs, args.async_batch_size)()
308
+ len_actor_device_ids = len(args.actor_device_ids)
309
+ global_step = 0
310
+ # TRY NOT TO MODIFY: start the game
311
+ start_time = time.time()
312
+
313
+ # put data in the last index
314
+ episode_returns = np.zeros((args.local_num_envs,), dtype=np.float32)
315
+ returned_episode_returns = np.zeros((args.local_num_envs,), dtype=np.float32)
316
+ episode_lengths = np.zeros((args.local_num_envs,), dtype=np.float32)
317
+ returned_episode_lengths = np.zeros((args.local_num_envs,), dtype=np.float32)
318
+ envs.async_reset()
319
+
320
+ params_queue_get_time = deque(maxlen=10)
321
+ rollout_time = deque(maxlen=10)
322
+ rollout_queue_put_time = deque(maxlen=10)
323
+ actor_policy_version = 0
324
+ for update in range(1, args.num_updates + 2):
325
+ # NOTE: This is a major difference from the sync version:
326
+ # at the end of the rollout phase, the sync version will have the next observation
327
+ # ready for the value bootstrap, but the async version will not have it.
328
+ # for this reason we do `num_steps + 1`` to get the extra states for value bootstrapping.
329
+ # but note that the extra states are not used for the loss computation in the next iteration,
330
+ # while the sync version will use the extra state for the loss computation.
331
+ update_time_start = time.time()
332
+ obs = []
333
+ dones = []
334
+ actions = []
335
+ logprobs = []
336
+ values = []
337
+ env_ids = []
338
+ rewards = []
339
+ truncations = []
340
+ terminations = []
341
+ env_recv_time = 0
342
+ inference_time = 0
343
+ storage_time = 0
344
+ env_send_time = 0
345
+
346
+ # NOTE: `update != 2` is actually IMPORTANT — it allows us to start running policy collection
347
+ # concurrently with the learning process. It also ensures the actor's policy version is only 1 step
348
+ # behind the learner's policy version
349
+ params_queue_get_time_start = time.time()
350
+ if not args.concurrency:
351
+ params = params_queue.get()
352
+ actor_policy_version += 1
353
+ else:
354
+ if update != 2:
355
+ params = params_queue.get()
356
+ actor_policy_version += 1
357
+ params_queue_get_time.append(time.time() - params_queue_get_time_start)
358
+ writer.add_scalar("stats/params_queue_get_time", np.mean(params_queue_get_time), global_step)
359
+ rollout_time_start = time.time()
360
+ for _ in range(
361
+ args.async_update, (args.num_steps + 1) * args.async_update
362
+ ): # num_steps + 1 to get the states for value bootstrapping.
363
+ env_recv_time_start = time.time()
364
+ next_obs, next_reward, next_done, info = envs.recv()
365
+ env_recv_time += time.time() - env_recv_time_start
366
+ global_step += len(next_done) * len_actor_device_ids * args.world_size
367
+ env_id = info["env_id"]
368
+
369
+ inference_time_start = time.time()
370
+ next_obs, action, logprob, value, key = get_action_and_value(params, next_obs, key, envs.single_action_space.n)
371
+ inference_time += time.time() - inference_time_start
372
+
373
+ env_send_time_start = time.time()
374
+ envs.send(np.array(action), env_id)
375
+ env_send_time += time.time() - env_send_time_start
376
+ storage_time_start = time.time()
377
+ obs.append(next_obs)
378
+ dones.append(next_done)
379
+ values.append(value)
380
+ actions.append(action)
381
+ logprobs.append(logprob)
382
+ env_ids.append(env_id)
383
+ rewards.append(next_reward)
384
+
385
+ # info["TimeLimit.truncated"] has a bug https://github.com/sail-sg/envpool/issues/239
386
+ # so we use our own truncated flag
387
+ truncated = info["elapsed_step"] >= envs.spec.config.max_episode_steps
388
+ truncations.append(truncated)
389
+ terminations.append(info["terminated"])
390
+ episode_returns[env_id] += info["reward"]
391
+ returned_episode_returns[env_id] = np.where(
392
+ info["terminated"] + truncated, episode_returns[env_id], returned_episode_returns[env_id]
393
+ )
394
+ episode_returns[env_id] *= (1 - info["terminated"]) * (1 - truncated)
395
+ episode_lengths[env_id] += 1
396
+ returned_episode_lengths[env_id] = np.where(
397
+ info["terminated"] + truncated, episode_lengths[env_id], returned_episode_lengths[env_id]
398
+ )
399
+ episode_lengths[env_id] *= (1 - info["terminated"]) * (1 - truncated)
400
+ storage_time += time.time() - storage_time_start
401
+ if args.profile:
402
+ action.block_until_ready()
403
+ rollout_time.append(time.time() - rollout_time_start)
404
+ writer.add_scalar("stats/rollout_time", np.mean(rollout_time), global_step)
405
+
406
+ avg_episodic_return = np.mean(returned_episode_returns)
407
+ writer.add_scalar("charts/avg_episodic_return", avg_episodic_return, global_step)
408
+ writer.add_scalar("charts/avg_episodic_length", np.mean(returned_episode_lengths), global_step)
409
+ print(f"global_step={global_step}, avg_episodic_return={avg_episodic_return}")
410
+ print("SPS:", int(global_step / (time.time() - start_time)))
411
+ writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
412
+
413
+ writer.add_scalar("stats/truncations", np.sum(truncations), global_step)
414
+ writer.add_scalar("stats/terminations", np.sum(terminations), global_step)
415
+ writer.add_scalar("stats/env_recv_time", env_recv_time, global_step)
416
+ writer.add_scalar("stats/inference_time", inference_time, global_step)
417
+ writer.add_scalar("stats/storage_time", storage_time, global_step)
418
+ writer.add_scalar("stats/env_send_time", env_send_time, global_step)
419
+ # `make_bulk_array` is actually important. It accumulates the data from the lists
420
+ # into single bulk arrays, which later makes transferring the data to the learner's
421
+ # device slightly faster. See https://wandb.ai/costa-huang/cleanRL/reports/data-transfer-optimization--VmlldzozNjU5MTg1
422
+ if args.learner_device_ids[0] != args.actor_device_ids[0]:
423
+ obs, values, actions, logprobs = make_bulk_array(
424
+ obs,
425
+ values,
426
+ actions,
427
+ logprobs,
428
+ )
429
+
430
+ payload = (
431
+ global_step,
432
+ actor_policy_version,
433
+ update,
434
+ obs,
435
+ values,
436
+ actions,
437
+ logprobs,
438
+ dones,
439
+ env_ids,
440
+ rewards,
441
+ np.mean(params_queue_get_time),
442
+ )
443
+ if update == 1 or not args.test_actor_learner_throughput:
444
+ rollout_queue_put_time_start = time.time()
445
+ rollout_queue.put(payload)
446
+ rollout_queue_put_time.append(time.time() - rollout_queue_put_time_start)
447
+ writer.add_scalar("stats/rollout_queue_put_time", np.mean(rollout_queue_put_time), global_step)
448
+
449
+ writer.add_scalar(
450
+ "charts/SPS_update",
451
+ int(
452
+ args.local_num_envs
453
+ * args.num_steps
454
+ * len_actor_device_ids
455
+ * args.world_size
456
+ / (time.time() - update_time_start)
457
+ ),
458
+ global_step,
459
+ )
460
+
461
+
462
+ @partial(jax.jit, static_argnums=(3))
463
+ def get_action_and_value2(
464
+ params: flax.core.FrozenDict,
465
+ x: np.ndarray,
466
+ action: np.ndarray,
467
+ action_dim: int,
468
+ ):
469
+ hidden = Network().apply(params.network_params, x)
470
+ logits = Actor(action_dim).apply(params.actor_params, hidden)
471
+ logprob = jax.nn.log_softmax(logits)[jnp.arange(action.shape[0]), action]
472
+ logits = logits - jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True)
473
+ logits = logits.clip(min=jnp.finfo(logits.dtype).min)
474
+ p_log_p = logits * jax.nn.softmax(logits)
475
+ entropy = -p_log_p.sum(-1)
476
+ value = Critic().apply(params.critic_params, hidden).squeeze()
477
+ return logprob, entropy, value
478
+
479
+
480
+ @jax.jit
481
+ def compute_gae(
482
+ env_ids: np.ndarray,
483
+ rewards: np.ndarray,
484
+ values: np.ndarray,
485
+ dones: np.ndarray,
486
+ ):
487
+ dones = jnp.asarray(dones)
488
+ values = jnp.asarray(values)
489
+ env_ids = jnp.asarray(env_ids)
490
+ rewards = jnp.asarray(rewards)
491
+
492
+ _, B = env_ids.shape
493
+ final_env_id_checked = jnp.zeros(args.local_num_envs, jnp.int32) - 1
494
+ final_env_ids = jnp.zeros(B, jnp.int32)
495
+ advantages = jnp.zeros(B)
496
+ lastgaelam = jnp.zeros(args.local_num_envs)
497
+ lastdones = jnp.zeros(args.local_num_envs) + 1
498
+ lastvalues = jnp.zeros(args.local_num_envs)
499
+
500
+ def compute_gae_once(carry, x):
501
+ lastvalues, lastdones, advantages, lastgaelam, final_env_ids, final_env_id_checked = carry
502
+ (
503
+ done,
504
+ value,
505
+ eid,
506
+ reward,
507
+ ) = x
508
+ nextnonterminal = 1.0 - lastdones[eid]
509
+ nextvalues = lastvalues[eid]
510
+ delta = jnp.where(final_env_id_checked[eid] == -1, 0, reward + args.gamma * nextvalues * nextnonterminal - value)
511
+ advantages = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam[eid]
512
+ final_env_ids = jnp.where(final_env_id_checked[eid] == 1, 1, 0)
513
+ final_env_id_checked = final_env_id_checked.at[eid].set(
514
+ jnp.where(final_env_id_checked[eid] == -1, 1, final_env_id_checked[eid])
515
+ )
516
+
517
+ # the last_ variables keeps track of the actual `num_steps`
518
+ lastgaelam = lastgaelam.at[eid].set(advantages)
519
+ lastdones = lastdones.at[eid].set(done)
520
+ lastvalues = lastvalues.at[eid].set(value)
521
+ return (lastvalues, lastdones, advantages, lastgaelam, final_env_ids, final_env_id_checked), (
522
+ advantages,
523
+ final_env_ids,
524
+ )
525
+
526
+ (_, _, _, _, final_env_ids, final_env_id_checked), (advantages, final_env_ids) = jax.lax.scan(
527
+ compute_gae_once,
528
+ (
529
+ lastvalues,
530
+ lastdones,
531
+ advantages,
532
+ lastgaelam,
533
+ final_env_ids,
534
+ final_env_id_checked,
535
+ ),
536
+ (
537
+ dones,
538
+ values,
539
+ env_ids,
540
+ rewards,
541
+ ),
542
+ reverse=True,
543
+ )
544
+ return advantages, advantages + values, final_env_id_checked, final_env_ids
545
+
546
+
547
+ def ppo_loss(params, x, a, logp, mb_advantages, mb_returns, action_dim):
548
+ newlogprob, entropy, newvalue = get_action_and_value2(params, x, a, action_dim)
549
+ logratio = newlogprob - logp
550
+ ratio = jnp.exp(logratio)
551
+ approx_kl = ((ratio - 1) - logratio).mean()
552
+
553
+ if args.norm_adv:
554
+ mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)
555
+
556
+ # Policy loss
557
+ pg_loss1 = -mb_advantages * ratio
558
+ pg_loss2 = -mb_advantages * jnp.clip(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
559
+ pg_loss = jnp.maximum(pg_loss1, pg_loss2).mean()
560
+
561
+ # Value loss
562
+ v_loss = 0.5 * ((newvalue - mb_returns) ** 2).mean()
563
+
564
+ entropy_loss = entropy.mean()
565
+ loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
566
+ return loss, (pg_loss, v_loss, entropy_loss, jax.lax.stop_gradient(approx_kl))
567
+
568
+
569
+ @partial(jax.jit, static_argnums=(6))
570
+ def single_device_update(
571
+ agent_state: TrainState,
572
+ b_obs,
573
+ b_actions,
574
+ b_logprobs,
575
+ b_advantages,
576
+ b_returns,
577
+ action_dim,
578
+ key: jax.random.PRNGKey,
579
+ ):
580
+ ppo_loss_grad_fn = jax.value_and_grad(ppo_loss, has_aux=True)
581
+
582
+ def update_epoch(carry, _):
583
+ agent_state, key = carry
584
+ key, subkey = jax.random.split(key)
585
+
586
+ # taken from: https://github.com/google/brax/blob/main/brax/training/agents/ppo/train.py
587
+ def convert_data(x: jnp.ndarray):
588
+ x = jax.random.permutation(subkey, x)
589
+ x = jnp.reshape(x, (args.num_minibatches, -1) + x.shape[1:])
590
+ return x
591
+
592
+ def update_minibatch(agent_state, minibatch):
593
+ mb_obs, mb_actions, mb_logprobs, mb_advantages, mb_returns = minibatch
594
+ (loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = ppo_loss_grad_fn(
595
+ agent_state.params,
596
+ mb_obs,
597
+ mb_actions,
598
+ mb_logprobs,
599
+ mb_advantages,
600
+ mb_returns,
601
+ action_dim,
602
+ )
603
+ grads = jax.lax.pmean(grads, axis_name="local_devices")
604
+ agent_state = agent_state.apply_gradients(grads=grads)
605
+ return agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl, grads)
606
+
607
+ agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl, grads) = jax.lax.scan(
608
+ update_minibatch,
609
+ agent_state,
610
+ (
611
+ convert_data(b_obs),
612
+ convert_data(b_actions),
613
+ convert_data(b_logprobs),
614
+ convert_data(b_advantages),
615
+ convert_data(b_returns),
616
+ ),
617
+ )
618
+ return (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl, grads)
619
+
620
+ (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl, _) = jax.lax.scan(
621
+ update_epoch, (agent_state, key), (), length=args.update_epochs
622
+ )
623
+ return agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, key
624
+
625
+
626
+ if __name__ == "__main__":
627
+ args = parse_args()
628
+ if args.distributed:
629
+ jax.distributed.initialize(
630
+ local_device_ids=range(len(args.learner_device_ids) + len(args.actor_device_ids)),
631
+ )
632
+ print(list(range(len(args.learner_device_ids) + len(args.actor_device_ids))))
633
+
634
+ args.world_size = jax.process_count()
635
+ args.local_rank = jax.process_index()
636
+ args.num_envs = args.local_num_envs * args.world_size
637
+ args.batch_size = args.local_batch_size * args.world_size
638
+ args.minibatch_size = args.local_minibatch_size * args.world_size
639
+ args.num_updates = args.total_timesteps // (args.local_batch_size * args.world_size)
640
+ args.async_update = int(args.local_num_envs / args.async_batch_size)
641
+ local_devices = jax.local_devices()
642
+ global_devices = jax.devices()
643
+ learner_devices = [local_devices[d_id] for d_id in args.learner_device_ids]
644
+ actor_devices = [local_devices[d_id] for d_id in args.actor_device_ids]
645
+ global_learner_decices = [
646
+ global_devices[d_id + process_index * len(local_devices)]
647
+ for process_index in range(args.world_size)
648
+ for d_id in args.learner_device_ids
649
+ ]
650
+ print("global_learner_decices", global_learner_decices)
651
+ args.global_learner_decices = [str(item) for item in global_learner_decices]
652
+ args.actor_devices = [str(item) for item in actor_devices]
653
+ args.learner_devices = [str(item) for item in learner_devices]
654
+
655
+ run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{uuid.uuid4()}"
656
+ if args.track and args.local_rank == 0:
657
+ import wandb
658
+
659
+ wandb.init(
660
+ project=args.wandb_project_name,
661
+ entity=args.wandb_entity,
662
+ sync_tensorboard=True,
663
+ config=vars(args),
664
+ name=run_name,
665
+ monitor_gym=True,
666
+ save_code=True,
667
+ )
668
+ writer = SummaryWriter(f"runs/{run_name}")
669
+ writer.add_text(
670
+ "hyperparameters",
671
+ "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
672
+ )
673
+
674
+ # TRY NOT TO MODIFY: seeding
675
+ random.seed(args.seed)
676
+ np.random.seed(args.seed)
677
+ key = jax.random.PRNGKey(args.seed)
678
+ key, network_key, actor_key, critic_key = jax.random.split(key, 4)
679
+
680
+ # env setup
681
+ envs = make_env(args.env_id, args.seed, args.local_num_envs, args.async_batch_size)()
682
+ assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"
683
+
684
+ def linear_schedule(count):
685
+ # anneal learning rate linearly after one training iteration which contains
686
+ # (args.num_minibatches * args.update_epochs) gradient updates
687
+ frac = 1.0 - (count // (args.num_minibatches * args.update_epochs)) / args.num_updates
688
+ return args.learning_rate * frac
689
+
690
+ network = Network()
691
+ actor = Actor(action_dim=envs.single_action_space.n)
692
+ critic = Critic()
693
+ network_params = network.init(network_key, np.array([envs.single_observation_space.sample()]))
694
+ agent_state = TrainState.create(
695
+ apply_fn=None,
696
+ params=AgentParams(
697
+ network_params,
698
+ actor.init(actor_key, network.apply(network_params, np.array([envs.single_observation_space.sample()]))),
699
+ critic.init(critic_key, network.apply(network_params, np.array([envs.single_observation_space.sample()]))),
700
+ ),
701
+ tx=optax.chain(
702
+ optax.clip_by_global_norm(args.max_grad_norm),
703
+ optax.inject_hyperparams(optax.adam)(
704
+ learning_rate=linear_schedule if args.anneal_lr else args.learning_rate, eps=1e-5
705
+ ),
706
+ ),
707
+ )
708
+ agent_state = flax.jax_utils.replicate(agent_state, devices=learner_devices)
709
+
710
+ multi_device_update = jax.pmap(
711
+ single_device_update,
712
+ axis_name="local_devices",
713
+ devices=global_learner_decices,
714
+ in_axes=(0, 0, 0, 0, 0, 0, None, None),
715
+ out_axes=(0, 0, 0, 0, 0, 0, None),
716
+ static_broadcasted_argnums=(6),
717
+ )
718
+
719
+ rollout_queue = queue.Queue(maxsize=1)
720
+ params_queues = []
721
+ for d_idx, d_id in enumerate(args.actor_device_ids):
722
+ params_queue = queue.Queue(maxsize=1)
723
+ params_queue.put(jax.device_put(flax.jax_utils.unreplicate(agent_state.params), local_devices[d_id]))
724
+ threading.Thread(
725
+ target=rollout,
726
+ args=(
727
+ jax.device_put(key, local_devices[d_id]),
728
+ args,
729
+ rollout_queue,
730
+ params_queue,
731
+ writer,
732
+ learner_devices,
733
+ ),
734
+ ).start()
735
+ params_queues.append(params_queue)
736
+
737
+ rollout_queue_get_time = deque(maxlen=10)
738
+ data_transfer_time = deque(maxlen=10)
739
+ learner_policy_version = 0
740
+ prepare_data = jax.jit(prepare_data, device=learner_devices[0])
741
+ while True:
742
+ learner_policy_version += 1
743
+ if learner_policy_version == 1 or not args.test_actor_learner_throughput:
744
+ rollout_queue_get_time_start = time.time()
745
+ (
746
+ global_step,
747
+ actor_policy_version,
748
+ update,
749
+ obs,
750
+ values,
751
+ actions,
752
+ logprobs,
753
+ dones,
754
+ env_ids,
755
+ rewards,
756
+ avg_params_queue_get_time,
757
+ ) = rollout_queue.get()
758
+ rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start)
759
+ writer.add_scalar("stats/rollout_queue_get_time", np.mean(rollout_queue_get_time), global_step)
760
+ writer.add_scalar(
761
+ "stats/rollout_params_queue_get_time_diff",
762
+ np.mean(rollout_queue_get_time) - avg_params_queue_get_time,
763
+ global_step,
764
+ )
765
+
766
+ data_transfer_time_start = time.time()
767
+ b_obs, b_actions, b_logprobs, b_advantages, b_returns = prepare_data(
768
+ obs,
769
+ dones,
770
+ values,
771
+ actions,
772
+ logprobs,
773
+ env_ids,
774
+ rewards,
775
+ )
776
+ b_obs = jnp.array_split(b_obs, len(learner_devices))
777
+ b_actions = jnp.array_split(b_actions, len(learner_devices))
778
+ b_logprobs = jnp.array_split(b_logprobs, len(learner_devices))
779
+ b_advantages = jnp.array_split(b_advantages, len(learner_devices))
780
+ b_returns = jnp.array_split(b_returns, len(learner_devices))
781
+ data_transfer_time.append(time.time() - data_transfer_time_start)
782
+ writer.add_scalar("stats/data_transfer_time", np.mean(data_transfer_time), global_step)
783
+
784
+ training_time_start = time.time()
785
+ (agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, key) = multi_device_update(
786
+ agent_state,
787
+ jax.device_put_sharded(b_obs, learner_devices),
788
+ jax.device_put_sharded(b_actions, learner_devices),
789
+ jax.device_put_sharded(b_logprobs, learner_devices),
790
+ jax.device_put_sharded(b_advantages, learner_devices),
791
+ jax.device_put_sharded(b_returns, learner_devices),
792
+ envs.single_action_space.n,
793
+ key,
794
+ )
795
+ if learner_policy_version == 1 or not args.test_actor_learner_throughput:
796
+ for d_idx, d_id in enumerate(args.actor_device_ids):
797
+ params_queues[d_idx].put(jax.device_put(flax.jax_utils.unreplicate(agent_state.params), local_devices[d_id]))
798
+ if args.profile:
799
+ v_loss[-1, -1, -1].block_until_ready()
800
+ writer.add_scalar("stats/training_time", time.time() - training_time_start, global_step)
801
+ writer.add_scalar("stats/rollout_queue_size", rollout_queue.qsize(), global_step)
802
+ writer.add_scalar("stats/params_queue_size", params_queue.qsize(), global_step)
803
+ print(
804
+ global_step,
805
+ f"actor_policy_version={actor_policy_version}, actor_update={update}, learner_policy_version={learner_policy_version}, training time: {time.time() - training_time_start}s",
806
+ )
807
+
808
+ # TRY NOT TO MODIFY: record rewards for plotting purposes
809
+ writer.add_scalar("charts/learning_rate", agent_state.opt_state[1].hyperparams["learning_rate"][0].item(), global_step)
810
+ writer.add_scalar("losses/value_loss", v_loss[-1, -1, -1].item(), global_step)
811
+ writer.add_scalar("losses/policy_loss", pg_loss[-1, -1, -1].item(), global_step)
812
+ writer.add_scalar("losses/entropy", entropy_loss[-1, -1, -1].item(), global_step)
813
+ writer.add_scalar("losses/approx_kl", approx_kl[-1, -1, -1].item(), global_step)
814
+ writer.add_scalar("losses/loss", loss[-1, -1, -1].item(), global_step)
815
+ if update >= args.num_updates:
816
+ break
817
+
818
+ if args.save_model and args.local_rank == 0:
819
+ if args.distributed:
820
+ jax.distributed.shutdown()
821
+ agent_state = flax.jax_utils.unreplicate(agent_state)
822
+ model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
823
+ with open(model_path, "wb") as f:
824
+ f.write(
825
+ flax.serialization.to_bytes(
826
+ [
827
+ vars(args),
828
+ [
829
+ agent_state.params.network_params,
830
+ agent_state.params.actor_params,
831
+ agent_state.params.critic_params,
832
+ ],
833
+ ]
834
+ )
835
+ )
836
+ print(f"model saved to {model_path}")
837
+ from cleanrl_utils.evals.ppo_envpool_jax_eval import evaluate
838
+
839
+ episodic_returns = evaluate(
840
+ model_path,
841
+ make_env,
842
+ args.env_id,
843
+ eval_episodes=10,
844
+ run_name=f"{run_name}-eval",
845
+ Model=(Network, Actor, Critic),
846
+ )
847
+ for idx, episodic_return in enumerate(episodic_returns):
848
+ writer.add_scalar("eval/episodic_return", episodic_return, idx)
849
+
850
+ if args.upload_model:
851
+ from cleanrl_utils.huggingface import push_to_hub
852
+
853
+ repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}"
854
+ repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name
855
+ push_to_hub(
856
+ args,
857
+ episodic_returns,
858
+ repo_id,
859
+ "PPO",
860
+ f"runs/{run_name}",
861
+ f"videos/{run_name}-eval",
862
+ extra_dependencies=["jax", "envpool", "atari"],
863
+ )
864
+
865
+ envs.close()
866
+ writer.close()
events.out.tfevents.1678205984.ip-26-0-135-190 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9bcf68cf8cb14aa39374963552fcfc8ce2d993c0e0e5c44c11df9dbb3631b36a
3
+ size 5017749
poetry.lock ADDED
The diff for this file is too large to render. See raw diff
 
pyproject.toml ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [tool.poetry]
2
+ name = "cleanba"
3
+ version = "0.1.0"
4
+ description = ""
5
+ authors = ["Costa Huang <costa.huang@outlook.com>"]
6
+ readme = "README.md"
7
+ packages = [
8
+ { include = "cleanba" },
9
+ { include = "cleanrl_utils" },
10
+ ]
11
+
12
+ [tool.poetry.dependencies]
13
+ python = "^3.8"
14
+ tensorboard = "^2.12.0"
15
+ envpool = "^0.8.1"
16
+ jax = "0.3.25"
17
+ flax = "0.6.0"
18
+ optax = "0.1.3"
19
+ huggingface-hub = "^0.12.0"
20
+ jaxlib = "0.3.25"
21
+ wandb = "^0.13.10"
22
+ tensorboardx = "^2.5.1"
23
+ chex = "0.1.5"
24
+ gym = "0.23.1"
25
+ opencv-python = "^4.7.0.68"
26
+ moviepy = "^1.0.3"
27
+
28
+
29
+ [tool.poetry.group.dev.dependencies]
30
+ pre-commit = "^3.0.4"
31
+
32
+ [build-system]
33
+ requires = ["poetry-core"]
34
+ build-backend = "poetry.core.masonry.api"
replay.mp4 ADDED
Binary file (817 kB). View file
 
videos/Breakout-v5__cleanba_ppo_envpool_impala_atari_wrapper__1__50845a10-8df4-40d7-a497-479aab048040-eval/0.mp4 ADDED
Binary file (817 kB). View file