sdpkjc commited on
Commit
3ac7da4
1 Parent(s): 336bca7

pushing model

Browse files
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ videos/HalfCheetah-v4__ppo_fix_continuous_action__5__1702935608-eval/rl-video-episode-8.mp4 filter=lfs diff=lfs merge=lfs -text
37
+ videos/HalfCheetah-v4__ppo_fix_continuous_action__5__1702935608-eval/rl-video-episode-1.mp4 filter=lfs diff=lfs merge=lfs -text
38
+ videos/HalfCheetah-v4__ppo_fix_continuous_action__5__1702935608-eval/rl-video-episode-0.mp4 filter=lfs diff=lfs merge=lfs -text
39
+ replay.mp4 filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ tags:
3
+ - HalfCheetah-v4
4
+ - deep-reinforcement-learning
5
+ - reinforcement-learning
6
+ - custom-implementation
7
+ library_name: cleanrl
8
+ model-index:
9
+ - name: PPO
10
+ results:
11
+ - task:
12
+ type: reinforcement-learning
13
+ name: reinforcement-learning
14
+ dataset:
15
+ name: HalfCheetah-v4
16
+ type: HalfCheetah-v4
17
+ metrics:
18
+ - type: mean_reward
19
+ value: 1636.48 +/- 11.27
20
+ name: mean_reward
21
+ verified: false
22
+ ---
23
+
24
+ # (CleanRL) **PPO** Agent Playing **HalfCheetah-v4**
25
+
26
+ This is a trained model of a PPO agent playing HalfCheetah-v4.
27
+ The model was trained by using [CleanRL](https://github.com/vwxyzjn/cleanrl) and the most up-to-date training code can be
28
+ found [here](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_fix_continuous_action.py).
29
+
30
+ ## Get Started
31
+
32
+ To use this model, please install the `cleanrl` package with the following command:
33
+
34
+ ```
35
+ pip install "cleanrl[ppo_fix_continuous_action]"
36
+ python -m cleanrl_utils.enjoy --exp-name ppo_fix_continuous_action --env-id HalfCheetah-v4
37
+ ```
38
+
39
+ Please refer to the [documentation](https://docs.cleanrl.dev/get-started/zoo/) for more detail.
40
+
41
+
42
+ ## Command to reproduce the training
43
+
44
+ ```bash
45
+ curl -OL https://huggingface.co/sdpkjc/HalfCheetah-v4-ppo_fix_continuous_action-seed5/raw/main/ppo_fix_continuous_action.py
46
+ curl -OL https://huggingface.co/sdpkjc/HalfCheetah-v4-ppo_fix_continuous_action-seed5/raw/main/pyproject.toml
47
+ curl -OL https://huggingface.co/sdpkjc/HalfCheetah-v4-ppo_fix_continuous_action-seed5/raw/main/poetry.lock
48
+ poetry install --all-extras
49
+ python ppo_fix_continuous_action.py --save-model --upload-model --hf-entity sdpkjc --env-id HalfCheetah-v4 --seed 5 --track
50
+ ```
51
+
52
+ # Hyperparameters
53
+ ```python
54
+ {'anneal_lr': True,
55
+ 'batch_size': 2048,
56
+ 'capture_video': False,
57
+ 'clip_coef': 0.2,
58
+ 'clip_vloss': True,
59
+ 'cuda': True,
60
+ 'ent_coef': 0.0,
61
+ 'env_id': 'HalfCheetah-v4',
62
+ 'exp_name': 'ppo_fix_continuous_action',
63
+ 'gae_lambda': 0.95,
64
+ 'gamma': 0.99,
65
+ 'hf_entity': 'sdpkjc',
66
+ 'learning_rate': 0.0003,
67
+ 'max_grad_norm': 0.5,
68
+ 'minibatch_size': 64,
69
+ 'norm_adv': True,
70
+ 'num_envs': 1,
71
+ 'num_minibatches': 32,
72
+ 'num_steps': 2048,
73
+ 'save_model': True,
74
+ 'seed': 5,
75
+ 'target_kl': None,
76
+ 'torch_deterministic': True,
77
+ 'total_timesteps': 1000000,
78
+ 'track': True,
79
+ 'update_epochs': 10,
80
+ 'upload_model': True,
81
+ 'vf_coef': 0.5,
82
+ 'wandb_entity': None,
83
+ 'wandb_project_name': 'cleanRL'}
84
+ ```
85
+
events.out.tfevents.1702935616.4090-171.247025.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:68a90c2f6933c56e324f6bbd1b410ffced6642a0a74d143904d7b7d28491b26a
3
+ size 376394
poetry.lock ADDED
The diff for this file is too large to render. See raw diff
 
ppo_fix_continuous_action.cleanrl_model ADDED
Binary file (50.2 kB). View file
 
ppo_fix_continuous_action.py ADDED
@@ -0,0 +1,572 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_continuous_actionpy
2
+ import argparse
3
+ import copy
4
+ import os
5
+ import random
6
+ import time
7
+ from distutils.util import strtobool
8
+ from typing import Callable
9
+
10
+ import gymnasium as gym
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.optim as optim
15
+ from torch.distributions.normal import Normal
16
+ from torch.utils.tensorboard import SummaryWriter
17
+
18
+
19
+ def parse_args():
20
+ # fmt: off
21
+ parser = argparse.ArgumentParser()
22
+ parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"),
23
+ help="the name of this experiment")
24
+ parser.add_argument("--seed", type=int, default=1,
25
+ help="seed of the experiment")
26
+ parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
27
+ help="if toggled, `torch.backends.cudnn.deterministic=False`")
28
+ parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
29
+ help="if toggled, cuda will be enabled by default")
30
+ parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
31
+ help="if toggled, this experiment will be tracked with Weights and Biases")
32
+ parser.add_argument("--wandb-project-name", type=str, default="cleanRL",
33
+ help="the wandb's project name")
34
+ parser.add_argument("--wandb-entity", type=str, default=None,
35
+ help="the entity (team) of wandb's project")
36
+ parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
37
+ help="whether to capture videos of the agent performances (check out `videos` folder)")
38
+ parser.add_argument("--save-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
39
+ help="whether to save model into the `runs/{run_name}` folder")
40
+ parser.add_argument("--upload-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
41
+ help="whether to upload the saved model to huggingface")
42
+ parser.add_argument("--hf-entity", type=str, default="",
43
+ help="the user or org name of the model repository from the Hugging Face Hub")
44
+
45
+ # Algorithm specific arguments
46
+ parser.add_argument("--env-id", type=str, default="HalfCheetah-v4",
47
+ help="the id of the environment")
48
+ parser.add_argument("--total-timesteps", type=int, default=1000000,
49
+ help="total timesteps of the experiments")
50
+ parser.add_argument("--learning-rate", type=float, default=3e-4,
51
+ help="the learning rate of the optimizer")
52
+ parser.add_argument("--num-envs", type=int, default=1,
53
+ help="the number of parallel game environments")
54
+ parser.add_argument("--num-steps", type=int, default=2048,
55
+ help="the number of steps to run in each environment per policy rollout")
56
+ parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
57
+ help="Toggle learning rate annealing for policy and value networks")
58
+ parser.add_argument("--gamma", type=float, default=0.99,
59
+ help="the discount factor gamma")
60
+ parser.add_argument("--gae-lambda", type=float, default=0.95,
61
+ help="the lambda for the general advantage estimation")
62
+ parser.add_argument("--num-minibatches", type=int, default=32,
63
+ help="the number of mini-batches")
64
+ parser.add_argument("--update-epochs", type=int, default=10,
65
+ help="the K epochs to update the policy")
66
+ parser.add_argument("--norm-adv", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
67
+ help="Toggles advantages normalization")
68
+ parser.add_argument("--clip-coef", type=float, default=0.2,
69
+ help="the surrogate clipping coefficient")
70
+ parser.add_argument("--clip-vloss", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
71
+ help="Toggles whether or not to use a clipped loss for the value function, as per the paper.")
72
+ parser.add_argument("--ent-coef", type=float, default=0.0,
73
+ help="coefficient of the entropy")
74
+ parser.add_argument("--vf-coef", type=float, default=0.5,
75
+ help="coefficient of the value function")
76
+ parser.add_argument("--max-grad-norm", type=float, default=0.5,
77
+ help="the maximum norm for the gradient clipping")
78
+ parser.add_argument("--target-kl", type=float, default=None,
79
+ help="the target KL divergence threshold")
80
+ args = parser.parse_args()
81
+ args.batch_size = int(args.num_envs * args.num_steps)
82
+ args.minibatch_size = int(args.batch_size // args.num_minibatches)
83
+ # fmt: on
84
+ return args
85
+
86
+
87
+ # https://github.com/Farama-Foundation/Gymnasium/blob/main/gymnasium/wrappers/normalize.py
88
+ class RunningMeanStd(nn.Module):
89
+ def __init__(self, epsilon=1e-4, shape=()):
90
+ super().__init__()
91
+ self.register_buffer("mean", torch.zeros(shape, dtype=torch.float64))
92
+ self.register_buffer("var", torch.ones(shape, dtype=torch.float64))
93
+ self.register_buffer("count", torch.tensor(epsilon, dtype=torch.float64))
94
+
95
+ def update(self, x):
96
+ x = torch.as_tensor(x, dtype=torch.float64).to(self.mean.device)
97
+ batch_mean = torch.mean(x, dim=0).to(self.mean.device)
98
+ batch_var = torch.var(x, dim=0, unbiased=False).to(self.mean.device)
99
+ batch_count = x.shape[0]
100
+
101
+ self.mean, self.var, self.count = update_mean_var_count_from_moments(
102
+ self.mean, self.var, self.count, batch_mean, batch_var, batch_count
103
+ )
104
+
105
+
106
+ def update_mean_var_count_from_moments(mean, var, count, batch_mean, batch_var, batch_count):
107
+ delta = batch_mean - mean
108
+ tot_count = count + batch_count
109
+
110
+ new_mean = mean + delta * batch_count / tot_count
111
+ m_a = var * count
112
+ m_b = batch_var * batch_count
113
+ M2 = m_a + m_b + torch.square(delta) * count * batch_count / tot_count
114
+ new_var = M2 / tot_count
115
+ new_count = tot_count
116
+
117
+ return new_mean, new_var, new_count
118
+
119
+
120
+ class NormalizeObservation(gym.Wrapper, gym.utils.RecordConstructorArgs):
121
+ def __init__(self, env: gym.Env, epsilon: float = 1e-8):
122
+ gym.utils.RecordConstructorArgs.__init__(self, epsilon=epsilon)
123
+ gym.Wrapper.__init__(self, env)
124
+
125
+ try:
126
+ self.num_envs = self.get_wrapper_attr("num_envs")
127
+ self.is_vector_env = self.get_wrapper_attr("is_vector_env")
128
+ except AttributeError:
129
+ self.num_envs = 1
130
+ self.is_vector_env = False
131
+
132
+ if self.is_vector_env:
133
+ self.obs_rms = RunningMeanStd(shape=self.single_observation_space.shape)
134
+ else:
135
+ self.obs_rms = RunningMeanStd(shape=self.observation_space.shape)
136
+ self.epsilon = epsilon
137
+
138
+ self.enable = True
139
+ self.freeze = False
140
+
141
+ def step(self, action):
142
+ obs, rews, terminateds, truncateds, infos = self.env.step(action)
143
+ if self.is_vector_env:
144
+ obs = self.normalize(obs)
145
+ else:
146
+ obs = self.normalize(np.array([obs]))[0]
147
+ return obs, rews, terminateds, truncateds, infos
148
+
149
+ def reset(self, **kwargs):
150
+ obs, info = self.env.reset(**kwargs)
151
+
152
+ if self.is_vector_env:
153
+ return self.normalize(obs), info
154
+ else:
155
+ return self.normalize(np.array([obs]))[0], info
156
+
157
+ def normalize(self, obs):
158
+ if not self.freeze:
159
+ self.obs_rms.update(obs)
160
+ if self.enable:
161
+ return (obs - self.obs_rms.mean.cpu().numpy()) / np.sqrt(self.obs_rms.var.cpu().numpy() + self.epsilon)
162
+ return obs
163
+
164
+
165
+ class NormalizeReward(gym.core.Wrapper, gym.utils.RecordConstructorArgs):
166
+ def __init__(
167
+ self,
168
+ env: gym.Env,
169
+ gamma: float = 0.99,
170
+ epsilon: float = 1e-8,
171
+ ):
172
+ gym.utils.RecordConstructorArgs.__init__(self, gamma=gamma, epsilon=epsilon)
173
+ gym.Wrapper.__init__(self, env)
174
+
175
+ try:
176
+ self.num_envs = self.get_wrapper_attr("num_envs")
177
+ self.is_vector_env = self.get_wrapper_attr("is_vector_env")
178
+ except AttributeError:
179
+ self.num_envs = 1
180
+ self.is_vector_env = False
181
+
182
+ self.return_rms = RunningMeanStd(shape=())
183
+ self.returns = np.zeros(self.num_envs)
184
+ self.gamma = gamma
185
+ self.epsilon = epsilon
186
+
187
+ self.enable = True
188
+ self.freeze = False
189
+
190
+ def step(self, action):
191
+ obs, rews, terminateds, truncateds, infos = self.env.step(action)
192
+ if not self.is_vector_env:
193
+ rews = np.array([rews])
194
+ self.returns = self.returns * self.gamma * (1 - terminateds) + rews
195
+ rews = self.normalize(rews)
196
+ if not self.is_vector_env:
197
+ rews = rews[0]
198
+ return obs, rews, terminateds, truncateds, infos
199
+
200
+ def reset(self, **kwargs):
201
+ self.returns = np.zeros(self.num_envs)
202
+ return self.env.reset(**kwargs)
203
+
204
+ def normalize(self, rews):
205
+ if not self.freeze:
206
+ self.return_rms.update(self.returns)
207
+ if self.enable:
208
+ return rews / np.sqrt(self.return_rms.var.cpu().numpy() + self.epsilon)
209
+ return rews
210
+
211
+ def get_returns(self):
212
+ return self.returns
213
+
214
+
215
+ def evaluate(
216
+ model_path: str,
217
+ make_env: Callable,
218
+ env_id: str,
219
+ eval_episodes: int,
220
+ run_name: str,
221
+ Model: torch.nn.Module,
222
+ device: torch.device = torch.device("cpu"),
223
+ capture_video: bool = True,
224
+ ):
225
+ envs = gym.vector.SyncVectorEnv([make_env(env_id, 0, capture_video, run_name)])
226
+ agent = Model(envs).to(device)
227
+ agent.load_state_dict(torch.load(model_path, map_location=device))
228
+ agent.eval()
229
+ envs = gym.vector.SyncVectorEnv([make_env(env_id, 0, capture_video, run_name, agent.obs_rms)])
230
+
231
+ obs, _ = envs.reset()
232
+ episodic_returns = []
233
+ while len(episodic_returns) < eval_episodes:
234
+ actions, _, _, _ = agent.get_action_and_value(torch.Tensor(obs).to(device))
235
+ next_obs, _, _, _, infos = envs.step(actions.cpu().numpy())
236
+ if "final_info" in infos:
237
+ for info in infos["final_info"]:
238
+ if "episode" not in info:
239
+ continue
240
+ print(f"eval_episode={len(episodic_returns)}, episodic_return={info['episode']['r']}")
241
+ episodic_returns += [info["episode"]["r"]]
242
+ obs = next_obs
243
+
244
+ return episodic_returns
245
+
246
+
247
+ def make_env(env_id, idx, capture_video, run_name, gamma):
248
+ def thunk():
249
+ if capture_video:
250
+ env = gym.make(env_id, render_mode="rgb_array")
251
+ else:
252
+ env = gym.make(env_id)
253
+ env = gym.wrappers.FlattenObservation(env) # deal with dm_control's Dict observation space
254
+ env = gym.wrappers.RecordEpisodeStatistics(env)
255
+ if capture_video:
256
+ if idx == 0:
257
+ env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
258
+ env = gym.wrappers.ClipAction(env)
259
+ env = NormalizeObservation(env)
260
+ env = gym.wrappers.TransformObservation(env, lambda obs: np.clip(obs, -10, 10))
261
+ env = NormalizeReward(env, gamma=gamma)
262
+ env = gym.wrappers.TransformReward(env, lambda reward: np.clip(reward, -10, 10))
263
+ return env
264
+
265
+ return thunk
266
+
267
+
268
+ def make_eval_env(env_id, idx, capture_video, run_name, obs_rms=None):
269
+ def thunk():
270
+ if capture_video:
271
+ env = gym.make(env_id, render_mode="rgb_array")
272
+ else:
273
+ env = gym.make(env_id)
274
+ env = gym.wrappers.FlattenObservation(env) # deal with dm_control's Dict observation space
275
+ env = gym.wrappers.RecordEpisodeStatistics(env)
276
+ if capture_video:
277
+ if idx == 0:
278
+ env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
279
+ env = gym.wrappers.ClipAction(env)
280
+ env = NormalizeObservation(env)
281
+ if obs_rms is not None:
282
+ env.obs_rms = copy.deepcopy(obs_rms)
283
+ env.freeze = True
284
+ env = gym.wrappers.TransformObservation(env, lambda obs: np.clip(obs, -10, 10))
285
+ return env
286
+
287
+ return thunk
288
+
289
+
290
+ def get_rms(env):
291
+ obs_rms, return_rms = None, None
292
+ env_point = env
293
+ while hasattr(env_point, "env"):
294
+ if isinstance(env_point, NormalizeObservation):
295
+ obs_rms = copy.deepcopy(env_point.obs_rms)
296
+ break
297
+ env_point = env_point.env
298
+ else:
299
+ raise RuntimeError("can't find NormalizeObservation")
300
+
301
+ env_point = env
302
+ while hasattr(env_point, "env"):
303
+ if isinstance(env_point, NormalizeReward):
304
+ return_rms = copy.deepcopy(env_point.return_rms)
305
+ break
306
+ env_point = env_point.env
307
+ else:
308
+ raise RuntimeError("can't find NormalizeReward")
309
+
310
+ return obs_rms, return_rms
311
+
312
+
313
+ def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
314
+ torch.nn.init.orthogonal_(layer.weight, std)
315
+ torch.nn.init.constant_(layer.bias, bias_const)
316
+ return layer
317
+
318
+
319
+ class Agent(nn.Module):
320
+ def __init__(self, envs):
321
+ super().__init__()
322
+ self.critic = nn.Sequential(
323
+ layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),
324
+ nn.Tanh(),
325
+ layer_init(nn.Linear(64, 64)),
326
+ nn.Tanh(),
327
+ layer_init(nn.Linear(64, 1), std=1.0),
328
+ )
329
+ self.actor_mean = nn.Sequential(
330
+ layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),
331
+ nn.Tanh(),
332
+ layer_init(nn.Linear(64, 64)),
333
+ nn.Tanh(),
334
+ layer_init(nn.Linear(64, np.prod(envs.single_action_space.shape)), std=0.01),
335
+ )
336
+ self.actor_logstd = nn.Parameter(torch.zeros(1, np.prod(envs.single_action_space.shape)))
337
+ self.obs_rms = RunningMeanStd(shape=envs.single_observation_space.shape)
338
+
339
+ def get_value(self, x):
340
+ return self.critic(x)
341
+
342
+ def get_action_and_value(self, x, action=None):
343
+ action_mean = self.actor_mean(x)
344
+ action_logstd = self.actor_logstd.expand_as(action_mean)
345
+ action_std = torch.exp(action_logstd)
346
+ probs = Normal(action_mean, action_std)
347
+ if action is None:
348
+ action = probs.sample()
349
+ return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x)
350
+
351
+
352
+ if __name__ == "__main__":
353
+ args = parse_args()
354
+ run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
355
+ if args.track:
356
+ import wandb
357
+
358
+ wandb.init(
359
+ project=args.wandb_project_name,
360
+ entity=args.wandb_entity,
361
+ sync_tensorboard=True,
362
+ config=vars(args),
363
+ name=run_name,
364
+ monitor_gym=True,
365
+ save_code=True,
366
+ )
367
+ writer = SummaryWriter(f"runs/{run_name}")
368
+ writer.add_text(
369
+ "hyperparameters",
370
+ "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
371
+ )
372
+
373
+ # TRY NOT TO MODIFY: seeding
374
+ random.seed(args.seed)
375
+ np.random.seed(args.seed)
376
+ torch.manual_seed(args.seed)
377
+ torch.backends.cudnn.deterministic = args.torch_deterministic
378
+
379
+ device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
380
+
381
+ # env setup
382
+ envs = gym.vector.SyncVectorEnv(
383
+ [make_env(args.env_id, i, args.capture_video, run_name, args.gamma) for i in range(args.num_envs)]
384
+ )
385
+ assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported"
386
+
387
+ agent = Agent(envs).to(device)
388
+ optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)
389
+
390
+ # ALGO Logic: Storage setup
391
+ obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
392
+ actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
393
+ logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device)
394
+ rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
395
+ dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
396
+ values = torch.zeros((args.num_steps, args.num_envs)).to(device)
397
+
398
+ # TRY NOT TO MODIFY: start the game
399
+ global_step = 0
400
+ start_time = time.time()
401
+ next_obs, _ = envs.reset(seed=args.seed)
402
+ next_obs = torch.Tensor(next_obs).to(device)
403
+ next_done = torch.zeros(args.num_envs).to(device)
404
+ num_updates = args.total_timesteps // args.batch_size
405
+
406
+ for update in range(1, num_updates + 1):
407
+ # Annealing the rate if instructed to do so.
408
+ if args.anneal_lr:
409
+ frac = 1.0 - (update - 1.0) / num_updates
410
+ lrnow = frac * args.learning_rate
411
+ optimizer.param_groups[0]["lr"] = lrnow
412
+
413
+ for step in range(0, args.num_steps):
414
+ global_step += 1 * args.num_envs
415
+ obs[step] = next_obs
416
+ dones[step] = next_done
417
+
418
+ # ALGO LOGIC: action logic
419
+ with torch.no_grad():
420
+ action, logprob, _, value = agent.get_action_and_value(next_obs)
421
+ values[step] = value.flatten()
422
+ actions[step] = action
423
+ logprobs[step] = logprob
424
+
425
+ # TRY NOT TO MODIFY: execute the game and log data.
426
+ next_obs, reward, terminations, truncations, infos = envs.step(action.cpu().numpy())
427
+ done = np.logical_or(terminations, truncations)
428
+ rewards[step] = torch.tensor(reward).to(device).view(-1)
429
+ next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device)
430
+
431
+ # https://github.com/DLR-RM/stable-baselines3/pull/658
432
+ for idx, trunc in enumerate(truncations):
433
+ if trunc:
434
+ real_next_obs = infos["final_observation"][idx]
435
+ with torch.no_grad():
436
+ terminal_value = agent.get_value(torch.Tensor(real_next_obs).to(device)).reshape(1, -1)[0][0]
437
+ rewards[step][idx] += args.gamma * terminal_value
438
+
439
+ # Only print when at least 1 env is done
440
+ if "final_info" not in infos:
441
+ continue
442
+
443
+ for info in infos["final_info"]:
444
+ # Skip the envs that are not done
445
+ if info is None:
446
+ continue
447
+ print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
448
+ writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
449
+ writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
450
+
451
+ # bootstrap value if not done
452
+ with torch.no_grad():
453
+ next_value = agent.get_value(next_obs).reshape(1, -1)
454
+ advantages = torch.zeros_like(rewards).to(device)
455
+ lastgaelam = 0
456
+ for t in reversed(range(args.num_steps)):
457
+ if t == args.num_steps - 1:
458
+ nextnonterminal = 1.0 - next_done
459
+ nextvalues = next_value
460
+ else:
461
+ nextnonterminal = 1.0 - dones[t + 1]
462
+ nextvalues = values[t + 1]
463
+ delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
464
+ advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
465
+ returns = advantages + values
466
+
467
+ # flatten the batch
468
+ b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
469
+ b_logprobs = logprobs.reshape(-1)
470
+ b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
471
+ b_advantages = advantages.reshape(-1)
472
+ b_returns = returns.reshape(-1)
473
+ b_values = values.reshape(-1)
474
+
475
+ # Optimizing the policy and value network
476
+ b_inds = np.arange(args.batch_size)
477
+ clipfracs = []
478
+ for epoch in range(args.update_epochs):
479
+ np.random.shuffle(b_inds)
480
+ for start in range(0, args.batch_size, args.minibatch_size):
481
+ end = start + args.minibatch_size
482
+ mb_inds = b_inds[start:end]
483
+
484
+ _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions[mb_inds])
485
+ logratio = newlogprob - b_logprobs[mb_inds]
486
+ ratio = logratio.exp()
487
+
488
+ with torch.no_grad():
489
+ # calculate approx_kl http://joschu.net/blog/kl-approx.html
490
+ old_approx_kl = (-logratio).mean()
491
+ approx_kl = ((ratio - 1) - logratio).mean()
492
+ clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()]
493
+
494
+ mb_advantages = b_advantages[mb_inds]
495
+ if args.norm_adv:
496
+ mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)
497
+
498
+ # Policy loss
499
+ pg_loss1 = -mb_advantages * ratio
500
+ pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
501
+ pg_loss = torch.max(pg_loss1, pg_loss2).mean()
502
+
503
+ # Value loss
504
+ newvalue = newvalue.view(-1)
505
+ if args.clip_vloss:
506
+ v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
507
+ v_clipped = b_values[mb_inds] + torch.clamp(
508
+ newvalue - b_values[mb_inds],
509
+ -args.clip_coef,
510
+ args.clip_coef,
511
+ )
512
+ v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
513
+ v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
514
+ v_loss = 0.5 * v_loss_max.mean()
515
+ else:
516
+ v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()
517
+
518
+ entropy_loss = entropy.mean()
519
+ loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
520
+
521
+ optimizer.zero_grad()
522
+ loss.backward()
523
+ nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
524
+ optimizer.step()
525
+
526
+ if args.target_kl is not None:
527
+ if approx_kl > args.target_kl:
528
+ break
529
+
530
+ y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
531
+ var_y = np.var(y_true)
532
+ explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
533
+
534
+ # TRY NOT TO MODIFY: record rewards for plotting purposes
535
+ writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
536
+ writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
537
+ writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
538
+ writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
539
+ writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step)
540
+ writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
541
+ writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
542
+ writer.add_scalar("losses/explained_variance", explained_var, global_step)
543
+ print("SPS:", int(global_step / (time.time() - start_time)))
544
+ writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
545
+
546
+ if args.save_model:
547
+ agent.obs_rms = copy.deepcopy(get_rms(envs.envs[0])[0])
548
+ model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
549
+ torch.save(agent.state_dict(), model_path)
550
+ print(f"model saved to {model_path}")
551
+
552
+ episodic_returns = evaluate(
553
+ model_path,
554
+ make_eval_env,
555
+ args.env_id,
556
+ eval_episodes=10,
557
+ run_name=f"{run_name}-eval",
558
+ Model=Agent,
559
+ device=device,
560
+ )
561
+ for idx, episodic_return in enumerate(episodic_returns):
562
+ writer.add_scalar("eval/episodic_return", episodic_return, idx)
563
+
564
+ if args.upload_model:
565
+ from cleanrl_utils.huggingface import push_to_hub
566
+
567
+ repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}"
568
+ repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name
569
+ push_to_hub(args, episodic_returns, repo_id, "PPO", f"runs/{run_name}", f"videos/{run_name}-eval")
570
+
571
+ envs.close()
572
+ writer.close()
pyproject.toml ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [tool.poetry]
2
+ name = "cleanrl"
3
+ version = "1.1.0"
4
+ description = "High-quality single file implementation of Deep Reinforcement Learning algorithms with research-friendly features"
5
+ authors = ["Costa Huang <costa.huang@outlook.com>"]
6
+ packages = [
7
+ { include = "cleanrl" },
8
+ { include = "cleanrl_utils" },
9
+ ]
10
+ keywords = ["reinforcement", "machine", "learning", "research"]
11
+ license="MIT"
12
+ readme = "README.md"
13
+
14
+ [tool.poetry.dependencies]
15
+ python = ">=3.7.1,<3.11"
16
+ tensorboard = "^2.10.0"
17
+ wandb = "^0.13.11"
18
+ gym = "0.23.1"
19
+ torch = ">=1.12.1"
20
+ stable-baselines3 = "1.2.0"
21
+ gymnasium = ">=0.28.1"
22
+ moviepy = "^1.0.3"
23
+ pygame = "2.1.0"
24
+ huggingface-hub = "^0.11.1"
25
+ rich = "<12.0"
26
+ tenacity = "^8.2.2"
27
+
28
+ ale-py = {version = "0.7.4", optional = true}
29
+ AutoROM = {extras = ["accept-rom-license"], version = "^0.4.2", optional = true}
30
+ opencv-python = {version = "^4.6.0.66", optional = true}
31
+ procgen = {version = "^0.10.7", optional = true}
32
+ pytest = {version = "^7.1.3", optional = true}
33
+ mujoco = {version = "<=2.3.3", optional = true}
34
+ imageio = {version = "^2.14.1", optional = true}
35
+ free-mujoco-py = {version = "^2.1.6", optional = true}
36
+ mkdocs-material = {version = "^8.4.3", optional = true}
37
+ markdown-include = {version = "^0.7.0", optional = true}
38
+ openrlbenchmark = {version = "^0.1.1b4", optional = true}
39
+ jax = {version = "^0.3.17", optional = true}
40
+ jaxlib = {version = "^0.3.15", optional = true}
41
+ flax = {version = "^0.6.0", optional = true}
42
+ optuna = {version = "^3.0.1", optional = true}
43
+ optuna-dashboard = {version = "^0.7.2", optional = true}
44
+ envpool = {version = "^0.6.4", optional = true}
45
+ PettingZoo = {version = "1.18.1", optional = true}
46
+ SuperSuit = {version = "3.4.0", optional = true}
47
+ multi-agent-ale-py = {version = "0.1.11", optional = true}
48
+ boto3 = {version = "^1.24.70", optional = true}
49
+ awscli = {version = "^1.25.71", optional = true}
50
+ shimmy = {version = ">=1.0.0", extras = ["dm-control"], optional = true}
51
+
52
+ [tool.poetry.group.dev.dependencies]
53
+ pre-commit = "^2.20.0"
54
+
55
+
56
+ [tool.poetry.group.isaacgym]
57
+ optional = true
58
+ [tool.poetry.group.isaacgym.dependencies]
59
+ isaacgymenvs = {git = "https://github.com/vwxyzjn/IsaacGymEnvs.git", rev = "poetry", python = ">=3.7.1,<3.10"}
60
+ isaacgym = {path = "cleanrl/ppo_continuous_action_isaacgym/isaacgym", develop = true}
61
+
62
+
63
+ [build-system]
64
+ requires = ["poetry-core"]
65
+ build-backend = "poetry.core.masonry.api"
66
+
67
+ [tool.poetry.extras]
68
+ atari = ["ale-py", "AutoROM", "opencv-python"]
69
+ procgen = ["procgen"]
70
+ plot = ["pandas", "seaborn"]
71
+ pytest = ["pytest"]
72
+ mujoco = ["mujoco", "imageio"]
73
+ mujoco_py = ["free-mujoco-py"]
74
+ jax = ["jax", "jaxlib", "flax"]
75
+ docs = ["mkdocs-material", "markdown-include", "openrlbenchmark"]
76
+ envpool = ["envpool"]
77
+ optuna = ["optuna", "optuna-dashboard"]
78
+ pettingzoo = ["PettingZoo", "SuperSuit", "multi-agent-ale-py"]
79
+ cloud = ["boto3", "awscli"]
80
+ dm_control = ["shimmy", "mujoco"]
81
+
82
+ # dependencies for algorithm variant (useful when you want to run a specific algorithm)
83
+ dqn = []
84
+ dqn_atari = ["ale-py", "AutoROM", "opencv-python"]
85
+ dqn_jax = ["jax", "jaxlib", "flax"]
86
+ dqn_atari_jax = [
87
+ "ale-py", "AutoROM", "opencv-python", # atari
88
+ "jax", "jaxlib", "flax" # jax
89
+ ]
90
+ c51 = []
91
+ c51_atari = ["ale-py", "AutoROM", "opencv-python"]
92
+ c51_jax = ["jax", "jaxlib", "flax"]
93
+ c51_atari_jax = [
94
+ "ale-py", "AutoROM", "opencv-python", # atari
95
+ "jax", "jaxlib", "flax" # jax
96
+ ]
97
+ ppo_atari_envpool_xla_jax_scan = [
98
+ "ale-py", "AutoROM", "opencv-python", # atari
99
+ "jax", "jaxlib", "flax", # jax
100
+ "envpool", # envpool
101
+ ]
102
+ qdagger_dqn_atari_impalacnn = [
103
+ "ale-py", "AutoROM", "opencv-python"
104
+ ]
105
+ qdagger_dqn_atari_jax_impalacnn = [
106
+ "ale-py", "AutoROM", "opencv-python", # atari
107
+ "jax", "jaxlib", "flax", # jax
108
+ ]
replay.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c2d4bb11d2f92e00b5038e07c51dc59085accee7fd4179d38117dfe84785be86
3
+ size 2020627
videos/HalfCheetah-v4__ppo_fix_continuous_action__5__1702935608-eval/rl-video-episode-0.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:caa2ebd338824bee0e92c013518e13ec90066c2bdfbd6f6528646a4276dde5df
3
+ size 2012444
videos/HalfCheetah-v4__ppo_fix_continuous_action__5__1702935608-eval/rl-video-episode-1.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2d8bcd5e574a03d7f72204c56700499f7a3d5ebfc004ff9a1e42d20adc41d757
3
+ size 1998690
videos/HalfCheetah-v4__ppo_fix_continuous_action__5__1702935608-eval/rl-video-episode-8.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c2d4bb11d2f92e00b5038e07c51dc59085accee7fd4179d38117dfe84785be86
3
+ size 2020627