ledmands committed on
Commit 650f88b
1 Parent(s): 08b1231

Added dqn_pacmanv5_run2.ipynb

Files changed (1)
  1. notebooks/dqn_pacmanv5_run2.ipynb +318 -0
notebooks/dqn_pacmanv5_run2.ipynb ADDED
@@ -0,0 +1,318 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "%%capture\n",
+     "!pip install stable-baselines3[extra]\n",
+     "!pip install moviepy"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "from stable_baselines3 import DQN\n",
+     "from stable_baselines3.common.monitor import Monitor\n",
+     "from stable_baselines3.common.callbacks import BaseCallback, EvalCallback, CallbackList\n",
+     "from stable_baselines3.common.logger import Video, HParam, TensorBoardOutputFormat\n",
+     "from stable_baselines3.common.evaluation import evaluate_policy\n",
+     "\n",
+     "from typing import Any, Dict\n",
+     "\n",
+     "import gymnasium as gym\n",
+     "import torch as th\n",
+     "import numpy as np\n",
+     "\n",
+     "# =====File names=====\n",
+     "MODEL_FILE_NAME = \"ALE-Pacman-v5\"\n",
+     "BUFFER_FILE_NAME = \"dqn_replay_buffer_pacman_v1\"\n",
+     "POLICY_FILE_NAME = \"dqn_policy_pacman_v1\"\n",
+     "\n",
+     "# =====Model Config=====\n",
+     "# Evaluate every tenth of training (NUM_TIMESTEPS / 10)\n",
+     "EVAL_CALLBACK_FREQ = 150_000\n",
+     "# Record every quarter of training (NUM_TIMESTEPS / 4); the final quarter won't trigger, so record it manually\n",
+     "VIDEO_CALLBACK_FREQ = 375_000\n",
+     "FRAMESKIP = 4\n",
+     "NUM_TIMESTEPS = 1_500_000\n",
+     "\n",
+     "# =====Hyperparams=====\n",
+     "EXPLORATION_FRACTION = 0.3\n",
+     "# Buffer size needs to be less than about 60k in order to save it in a Kaggle instance\n",
+     "BUFFER_SIZE = 60_000\n",
+     "BATCH_SIZE = 64\n",
+     "LEARNING_STARTS = 50_000\n",
+     "LEARNING_RATE = 0.0002\n",
+     "GAMMA = 0.999\n",
+     "FINAL_EPSILON = 0.1\n",
+     "# Target update interval is 10k by default in SB3 but looks like it is set to 4 in the Nature paper.\n",
+     "# This is a large discrepancy and makes me wonder if it is something different or measured differently...\n",
+     "TARGET_UPDATE_INTERVAL = 1_000"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# VideoRecorderCallback\n",
+     "# The VideoRecorderCallback records a video of the agent in the evaluation environment\n",
+     "# every render_freq timesteps. It records one episode each time, and one more episode\n",
+     "# when training has completed.\n",
+     "\n",
+     "class VideoRecorderCallback(BaseCallback):\n",
+     "    def __init__(self, eval_env: gym.Env, render_freq: int, n_eval_episodes: int = 1, deterministic: bool = True):\n",
+     "        \"\"\"\n",
+     "        Records a video of an agent's trajectory traversing ``eval_env`` and logs it to TensorBoard.\n",
+     "        :param eval_env: A gym environment from which the trajectory is recorded\n",
+     "        :param render_freq: Render the agent's trajectory every render_freq calls of the callback.\n",
+     "        :param n_eval_episodes: Number of episodes to render\n",
+     "        :param deterministic: Whether to use deterministic or stochastic policy\n",
+     "        \"\"\"\n",
+     "        super().__init__()\n",
+     "        self._eval_env = eval_env\n",
+     "        self._render_freq = render_freq\n",
+     "        self._n_eval_episodes = n_eval_episodes\n",
+     "        self._deterministic = deterministic\n",
+     "\n",
+     "    def _on_step(self) -> bool:\n",
+     "        if self.n_calls % self._render_freq == 0:\n",
+     "            screens = []\n",
+     "\n",
+     "            def grab_screens(_locals: Dict[str, Any], _globals: Dict[str, Any]) -> None:\n",
+     "                \"\"\"\n",
+     "                Renders the environment in its current state, recording the screen in the captured `screens` list\n",
+     "                :param _locals: A dictionary containing all local variables of the callback's scope\n",
+     "                :param _globals: A dictionary containing all global variables of the callback's scope\n",
+     "                \"\"\"\n",
+     "                screen = self._eval_env.render()\n",
+     "                # PyTorch uses CxHxW vs HxWxC gym (and tensorflow) image convention\n",
+     "                screens.append(screen.transpose(2, 0, 1))\n",
+     "\n",
+     "            evaluate_policy(\n",
+     "                self.model,\n",
+     "                self._eval_env,\n",
+     "                callback=grab_screens,\n",
+     "                n_eval_episodes=self._n_eval_episodes,\n",
+     "                deterministic=self._deterministic,\n",
+     "            )\n",
+     "            self.logger.record(\n",
+     "                \"trajectory/video\",\n",
+     "                Video(th.from_numpy(np.array([screens])), fps=60),\n",
+     "                exclude=(\"stdout\", \"log\", \"json\", \"csv\"),\n",
+     "            )\n",
+     "        return True"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# HParamCallback\n",
+     "# Logs the specified hyperparameters at the start of training and maps the logged metrics\n",
+     "# to the appropriate run.\n",
+     "class HParamCallback(BaseCallback):\n",
+     "    \"\"\"\n",
+     "    Saves the hyperparameters and metrics at the start of the training, and logs them to TensorBoard.\n",
+     "    \"\"\"\n",
+     "    def __init__(self):\n",
+     "        super().__init__()\n",
+     "\n",
+     "    def _on_training_start(self) -> None:\n",
+     "        hparam_dict = {\n",
+     "            \"algorithm\": self.model.__class__.__name__,\n",
+     "            \"policy\": self.model.policy.__class__.__name__,\n",
+     "            \"environment\": self.model.env.__class__.__name__,\n",
+     "            \"buffer_size\": self.model.buffer_size,\n",
+     "            \"batch_size\": self.model.batch_size,\n",
+     "            \"tau\": self.model.tau,\n",
+     "            \"gradient_steps\": self.model.gradient_steps,\n",
+     "            \"target_update_interval\": self.model.target_update_interval,\n",
+     "            \"exploration_fraction\": self.model.exploration_fraction,\n",
+     "            \"exploration_initial_eps\": self.model.exploration_initial_eps,\n",
+     "            \"exploration_final_eps\": self.model.exploration_final_eps,\n",
+     "            \"max_grad_norm\": self.model.max_grad_norm,\n",
+     "            \"tensorboard_log\": self.model.tensorboard_log,\n",
+     "            \"seed\": self.model.seed,\n",
+     "            \"learning rate\": self.model.learning_rate,\n",
+     "            \"gamma\": self.model.gamma,\n",
+     "        }\n",
+     "        # Define the metrics that will appear in the `HPARAMS` TensorBoard tab by referencing their tag.\n",
+     "        # TensorBoard will find & display these metrics from the `SCALARS` tab.\n",
+     "        metric_dict = {\n",
+     "            \"eval/mean_ep_length\": 0,\n",
+     "            \"eval/mean_reward\": 0,\n",
+     "            \"rollout/ep_len_mean\": 0,\n",
+     "            \"rollout/ep_rew_mean\": 0,\n",
+     "            \"rollout/exploration_rate\": 0,\n",
+     "            \"time/_episode_num\": 0,\n",
+     "            \"time/fps\": 0,\n",
+     "            \"time/total_timesteps\": 0,\n",
+     "            \"train/learning_rate\": 0.0,\n",
+     "            \"train/loss\": 0.0,\n",
+     "            \"train/n_updates\": 0.0,\n",
+     "            \"locals/rewards\": 0.0,\n",
+     "            \"locals/infos_0_lives\": 0.0,\n",
+     "            \"locals/num_collected_steps\": 0.0,\n",
+     "            \"locals/num_collected_episodes\": 0.0\n",
+     "        }\n",
+     "\n",
+     "        self.logger.record(\n",
+     "            \"hparams\",\n",
+     "            HParam(hparam_dict, metric_dict),\n",
+     "            exclude=(\"stdout\", \"log\", \"json\", \"csv\"),\n",
+     "        )"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# PlotTensorboardValuesCallback\n",
+     "# This callback logs additional scalar values to TensorBoard via self.logger.record on every step,\n",
+     "# and writes environment/model metadata as text at the start and end of training.\n",
+     "\n",
+     "class PlotTensorboardValuesCallback(BaseCallback):\n",
+     "    \"\"\"\n",
+     "    Custom callback for plotting additional values in tensorboard.\n",
+     "    \"\"\"\n",
+     "    def __init__(self, eval_env: gym.Env, train_env: gym.Env, model: DQN, verbose=0):\n",
+     "        super().__init__(verbose)\n",
+     "        self._eval_env = eval_env\n",
+     "        self._train_env = train_env\n",
+     "        self._model = model\n",
+     "\n",
+     "    def _on_training_start(self) -> None:\n",
+     "        output_formats = self.logger.output_formats\n",
+     "        # Save a reference to the TensorBoard formatter object.\n",
+     "        # Note: if no TensorBoard formatter is found, the except branch below only prints a message.\n",
+     "        try:\n",
+     "            self.tb_formatter = next(formatter for formatter in output_formats if isinstance(formatter, TensorBoardOutputFormat))\n",
+     "        except StopIteration:\n",
+     "            print(\"Exception thrown in tb_formatter initialization.\")\n",
+     "\n",
+     "        self.tb_formatter.writer.add_text(\"metadata/eval_env\", str(self._eval_env.metadata), self.num_timesteps)\n",
+     "        self.tb_formatter.writer.flush()\n",
+     "        self.tb_formatter.writer.add_text(\"metadata/train_env\", str(self._train_env.metadata), self.num_timesteps)\n",
+     "        self.tb_formatter.writer.flush()\n",
+     "        self.tb_formatter.writer.add_text(\"model/q_net\", str(self._model.q_net), self.num_timesteps)\n",
+     "        self.tb_formatter.writer.flush()\n",
+     "        self.tb_formatter.writer.add_text(\"model/q_net_target\", str(self._model.q_net_target), self.num_timesteps)\n",
+     "        self.tb_formatter.writer.flush()\n",
+     "\n",
+     "    def _on_step(self) -> bool:\n",
+     "        self.logger.record(\"time/_episode_num\", self.model._episode_num, exclude=(\"stdout\", \"log\", \"json\", \"csv\"))\n",
+     "        self.logger.record(\"train/n_updates\", self.model._n_updates, exclude=(\"stdout\", \"log\", \"json\", \"csv\"))\n",
+     "        self.logger.record(\"locals/rewards\", self.locals[\"rewards\"], exclude=(\"stdout\", \"log\", \"json\", \"csv\"))\n",
+     "        self.logger.record(\"locals/infos_0_lives\", self.locals[\"infos\"][0][\"lives\"], exclude=(\"stdout\", \"log\", \"json\", \"csv\"))\n",
+     "        self.logger.record(\"locals/num_collected_steps\", self.locals[\"num_collected_steps\"], exclude=(\"stdout\", \"log\", \"json\", \"csv\"))\n",
+     "        self.logger.record(\"locals/num_collected_episodes\", self.locals[\"num_collected_episodes\"], exclude=(\"stdout\", \"log\", \"json\", \"csv\"))\n",
+     "\n",
+     "        return True\n",
+     "\n",
+     "    def _on_training_end(self) -> None:\n",
+     "        self.tb_formatter.writer.add_text(\"metadata/eval_env\", str(self._eval_env.metadata), self.num_timesteps)\n",
+     "        self.tb_formatter.writer.flush()\n",
+     "        self.tb_formatter.writer.add_text(\"metadata/train_env\", str(self._train_env.metadata), self.num_timesteps)\n",
+     "        self.tb_formatter.writer.flush()\n",
+     "        self.tb_formatter.writer.add_text(\"model/q_net\", str(self._model.q_net), self.num_timesteps)\n",
+     "        self.tb_formatter.writer.flush()\n",
+     "        self.tb_formatter.writer.add_text(\"model/q_net_target\", str(self._model.q_net_target), self.num_timesteps)\n",
+     "        self.tb_formatter.writer.flush()"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# Make the training and evaluation environments\n",
+     "eval_env = Monitor(gym.make(\"ALE/Pacman-v5\", render_mode=\"rgb_array\", frameskip=FRAMESKIP))\n",
+     "train_env = gym.make(\"ALE/Pacman-v5\", render_mode=\"rgb_array\", frameskip=FRAMESKIP)\n",
+     "\n",
+     "# Make the model with the specified hyperparams\n",
+     "model = DQN(\n",
+     "    \"CnnPolicy\",\n",
+     "    train_env,\n",
+     "    verbose=1,\n",
+     "    buffer_size=BUFFER_SIZE,\n",
+     "    exploration_fraction=EXPLORATION_FRACTION,\n",
+     "    batch_size=BATCH_SIZE,\n",
+     "    exploration_final_eps=FINAL_EPSILON,\n",
+     "    gamma=GAMMA,\n",
+     "    learning_starts=LEARNING_STARTS,\n",
+     "    learning_rate=LEARNING_RATE,\n",
+     "    target_update_interval=TARGET_UPDATE_INTERVAL,\n",
+     "    tensorboard_log=\"./\",\n",
+     ")"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# Define the callbacks and put them in a list\n",
+     "eval_callback = EvalCallback(\n",
+     "    eval_env,\n",
+     "    best_model_save_path=\"./best_model/\",\n",
+     "    log_path=\"./evals/\",\n",
+     "    eval_freq=EVAL_CALLBACK_FREQ,\n",
+     "    n_eval_episodes=10,\n",
+     "    deterministic=True,\n",
+     "    render=False)\n",
+     "\n",
+     "tbplot_callback = PlotTensorboardValuesCallback(eval_env=eval_env, train_env=train_env, model=model)\n",
+     "video_callback = VideoRecorderCallback(eval_env, render_freq=VIDEO_CALLBACK_FREQ)\n",
+     "hparam_callback = HParamCallback()\n",
+     "\n",
+     "callback_list = CallbackList([hparam_callback, eval_callback, video_callback, tbplot_callback])"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# Train the model\n",
+     "model.learn(total_timesteps=NUM_TIMESTEPS, callback=callback_list, tb_log_name=\"./tb/\")"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# Save the model, policy, and replay buffer for future loading and training\n",
+     "model.save(MODEL_FILE_NAME)\n",
+     "model.save_replay_buffer(BUFFER_FILE_NAME)\n",
+     "model.policy.save(POLICY_FILE_NAME)"
+    ]
+   }
+  ],
+  "metadata": {
+   "language_info": {
+    "name": "python"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 2
+ }
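
The last cell saves the model, replay buffer, and policy so a later run can continue training from this checkpoint. A minimal sketch of how a follow-up notebook might reload these artifacts with the standard Stable-Baselines3 loaders (the file names match the constants above; the `resumed_model` name and the extra 500k steps are illustrative, not part of this commit):

# Sketch: resume training from the artifacts saved above (assumes the same environment setup).
import gymnasium as gym
from stable_baselines3 import DQN

# Rebuild the training environment the same way as in the notebook.
train_env = gym.make("ALE/Pacman-v5", render_mode="rgb_array", frameskip=4)

# Load the saved model and re-attach the environment and TensorBoard directory.
resumed_model = DQN.load("ALE-Pacman-v5", env=train_env, tensorboard_log="./")

# Restore the replay buffer so learning does not restart from an empty buffer.
resumed_model.load_replay_buffer("dqn_replay_buffer_pacman_v1")

# Continue training without resetting the timestep counter (step count here is illustrative).
resumed_model.learn(total_timesteps=500_000, reset_num_timesteps=False, tb_log_name="./tb/")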