gabehubner commited on
Commit
ee1c253
1 Parent(s): f6f3371

add requirements

Browse files
__pycache__/ddpg.cpython-311.pyc CHANGED
Binary files a/__pycache__/ddpg.cpython-311.pyc and b/__pycache__/ddpg.cpython-311.pyc differ
 
__pycache__/train.cpython-311.pyc CHANGED
Binary files a/__pycache__/train.cpython-311.pyc and b/__pycache__/train.cpython-311.pyc differ
 
app.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ import gradio as gr
2
+ from train import TrainingLoop
3
+
ddpg.py CHANGED
@@ -144,10 +144,6 @@ class ActorNetwork(nn.Module):
144
 
145
  def forward(self, state):
146
 
147
- try:
148
- assert state.shape == T.Size([8])
149
- except AssertionError:
150
- raise Exception(f"Wrong shape {state.shape=}")
151
 
152
  x = self.fc1(state)
153
  x = self.bn1(x)
@@ -182,7 +178,7 @@ class Agent(object):
182
 
183
  self.noise = OUActionNoise(mu=np.zeros(n_actions))
184
 
185
- self.attributions = None
186
  self.ig : IntegratedGradients = None
187
 
188
  self.update_network_parameters(tau=1)
@@ -195,7 +191,7 @@ class Agent(object):
195
 
196
  if self.ig is not None:
197
  attribution = self.ig.attribute(observation, baselines=baseline, n_steps=1)
198
- print('Attributions:', attribution)
199
 
200
 
201
  mu_prime = mu + T.tensor(self.noise(), dtype=T.float).to(self.actor.device)
 
144
 
145
  def forward(self, state):
146
 
 
 
 
 
147
 
148
  x = self.fc1(state)
149
  x = self.bn1(x)
 
178
 
179
  self.noise = OUActionNoise(mu=np.zeros(n_actions))
180
 
181
+ self.attributions = []
182
  self.ig : IntegratedGradients = None
183
 
184
  self.update_network_parameters(tau=1)
 
191
 
192
  if self.ig is not None:
193
  attribution = self.ig.attribute(observation, baselines=baseline, n_steps=1)
194
+ self.attributions.append(attribution)
195
 
196
 
197
  mu_prime = mu + T.tensor(self.noise(), dtype=T.float).to(self.actor.device)
main.py CHANGED
@@ -7,11 +7,11 @@ import argparse
7
  from train import TrainingLoop
8
  from captum.attr import (IntegratedGradients, LayerConductance, NeuronAttribution)
9
 
10
- training_loop = TrainingLoop(env_spec="LunarLander-v2", continuous=True, gravity=-10, render_mode=None)
11
  training_loop.create_agent()
12
 
13
  parser = argparse.ArgumentParser(description="Choose a function to run.")
14
- parser.add_argument("function", choices=["train", "load-trained", "attribute"], help="The function to run.")
15
 
16
  args = parser.parse_args()
17
 
@@ -20,4 +20,7 @@ if args.function == "train":
20
  elif args.function == "load-trained":
21
  training_loop.load_trained()
22
  elif args.function == "attribute":
23
- training_loop.explain_trained(option="2", num_iterations=10)
 
 
 
 
7
  from train import TrainingLoop
8
  from captum.attr import (IntegratedGradients, LayerConductance, NeuronAttribution)
9
 
10
+ training_loop = TrainingLoop(env_spec="LunarLander-v2", continuous=True, gravity=-10)
11
  training_loop.create_agent()
12
 
13
  parser = argparse.ArgumentParser(description="Choose a function to run.")
14
+ parser.add_argument("function", choices=["train", "load-trained", "attribute", "video"], help="The function to run.")
15
 
16
  args = parser.parse_args()
17
 
 
20
  elif args.function == "load-trained":
21
  training_loop.load_trained()
22
  elif args.function == "attribute":
23
+ frames, attributions = training_loop.explain_trained(option="2", num_iterations=10)
24
+ elif args.function == "video":
25
+ training_loop.render_video(20)
26
+
requirements.txt ADDED
@@ -0,0 +1,392 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==2.0.0
2
+ aiofiles==23.2.1
3
+ aiohttp==3.8.5
4
+ aiosignal==1.3.1
5
+ alabaster==0.7.13
6
+ ale-py==0.8.1
7
+ altair==5.2.0
8
+ annotated-types==0.6.0
9
+ anyio==3.7.1
10
+ appdirs==1.4.4
11
+ appnope==0.1.3
12
+ argon2-cffi==23.1.0
13
+ argon2-cffi-bindings==21.2.0
14
+ arrow==1.3.0
15
+ astatine==0.3.3
16
+ astor==0.8.1
17
+ astpretty==3.0.0
18
+ astroid==2.15.8
19
+ asttokens==2.4.0
20
+ astunparse==1.6.3
21
+ async-timeout==4.0.3
22
+ attrs==23.1.0
23
+ autoflake==1.7.8
24
+ AutoROM==0.4.2
25
+ AutoROM.accept-rom-license==0.6.1
26
+ Babel==2.13.0
27
+ backcall==0.2.0
28
+ bandit==1.7.5
29
+ beautifulsoup4==4.12.2
30
+ bitmath==1.3.3.1
31
+ black==23.10.0
32
+ bleach==6.1.0
33
+ box2d-py==2.3.5
34
+ Brotli==1.1.0
35
+ cachetools==5.3.1
36
+ captum==0.6.0
37
+ certifi==2023.7.22
38
+ cffi==1.16.0
39
+ chardet==4.0.0
40
+ charset-normalizer==3.3.0
41
+ chess==1.9.4
42
+ click==7.1.2
43
+ cloudpickle==1.3.0
44
+ cmake==3.27.7
45
+ cognitive-complexity==1.3.0
46
+ colorama==0.4.6
47
+ comm==0.1.4
48
+ contourpy==1.1.1
49
+ coverage==7.3.2
50
+ cycler==0.12.0
51
+ darglint==1.8.1
52
+ debugpy==1.8.0
53
+ decorator==4.4.2
54
+ defusedxml==0.7.1
55
+ deprecation==2.1.0
56
+ DI-engine==0.4.9
57
+ DI-toolkit==0.2.0
58
+ DI-treetensor==0.4.1
59
+ dill==0.3.7
60
+ distlib==0.3.7
61
+ dlint==0.14.1
62
+ doc8==1.1.1
63
+ docformatter==1.7.5
64
+ docker-pycreds==0.4.0
65
+ docutils==0.19
66
+ domdf-python-tools==3.6.1
67
+ easydict==1.9
68
+ entrypoints==0.4
69
+ enum-tools==0.11.0
70
+ eradicate==2.3.0
71
+ executing==2.0.0
72
+ Farama-Notifications==0.0.4
73
+ fastapi==0.104.0
74
+ fastjsonschema==2.18.1
75
+ ffmpeg==1.4
76
+ ffmpy==0.3.1
77
+ filelock==3.12.4
78
+ flake8==5.0.4
79
+ flake8-2020==1.8.1
80
+ flake8-aaa==0.16.0
81
+ flake8-annotations==3.0.1
82
+ flake8-annotations-complexity==0.0.8
83
+ flake8-annotations-coverage==0.0.6
84
+ flake8-bandit==4.1.1
85
+ flake8-black==0.3.6
86
+ flake8-blind-except==0.2.1
87
+ flake8-breakpoint==1.1.0
88
+ flake8-broken-line==0.6.0
89
+ flake8-bugbear==23.3.12
90
+ flake8-builtins==1.5.3
91
+ flake8-class-attributes-order==0.1.3
92
+ flake8-coding==1.3.2
93
+ flake8-cognitive-complexity==0.1.0
94
+ flake8-comments==0.1.2
95
+ flake8-comprehensions==3.14.0
96
+ flake8-debugger==4.1.2
97
+ flake8-django==1.4
98
+ flake8-docstrings==1.7.0
99
+ flake8-encodings==0.5.0.post1
100
+ flake8-eradicate==1.5.0
101
+ flake8-executable==2.1.3
102
+ flake8-expression-complexity==0.0.11
103
+ flake8-fastapi==0.7.0
104
+ flake8-fixme==1.1.1
105
+ flake8-functions==0.0.8
106
+ flake8-functions-names==0.4.0
107
+ flake8-future-annotations==0.0.5
108
+ flake8-helper==0.2.1
109
+ flake8-isort==6.1.0
110
+ flake8-literal==1.3.0
111
+ flake8-logging-format==0.9.0
112
+ flake8-markdown==0.5.0
113
+ flake8-mutable==1.2.0
114
+ flake8-no-pep420==2.7.0
115
+ flake8-noqa==1.3.2
116
+ flake8-pie==0.16.0
117
+ flake8-plugin-utils==1.3.3
118
+ flake8-pyi==22.11.0
119
+ flake8-pylint==0.2.1
120
+ flake8-pytest-style==1.7.2
121
+ flake8-quotes==3.3.2
122
+ flake8-rst-docstrings==0.3.0
123
+ flake8-secure-coding-standard==1.4.0
124
+ flake8-string-format==0.3.0
125
+ flake8-tidy-imports==4.10.0
126
+ flake8-typing-imports==1.15.0
127
+ flake8-use-fstring==1.4
128
+ flake8-use-pathlib==0.3.0
129
+ flake8-useless-assert==0.4.4
130
+ flake8-variables-names==0.0.6
131
+ flake8-warnings==0.4.0
132
+ flake8_simplify==0.21.0
133
+ Flask==1.1.4
134
+ Flask-Compress==1.14
135
+ flatbuffers==23.5.26
136
+ fonttools==4.43.1
137
+ fqdn==1.5.1
138
+ frozenlist==1.4.0
139
+ fsspec==2023.9.2
140
+ future==0.18.3
141
+ gast==0.5.4
142
+ gitdb==4.0.11
143
+ GitPython==3.1.40
144
+ glfw==2.6.2
145
+ google-auth==2.23.3
146
+ google-auth-oauthlib==1.0.0
147
+ google-pasta==0.2.0
148
+ gradio==4.7.1
149
+ gradio_client==0.7.0
150
+ graphviz==0.20.1
151
+ grpcio==1.59.0
152
+ gym==0.25.1
153
+ gym-notices==0.0.8
154
+ gymnasium==0.29.1
155
+ h11==0.14.0
156
+ h5py==3.10.0
157
+ hbutils==0.9.1
158
+ hickle==5.0.2
159
+ httpcore==1.0.2
160
+ httpx==0.25.2
161
+ huggingface-hub==0.19.4
162
+ hypothesis==6.88.1
163
+ hypothesmith==0.1.9
164
+ idna==3.4
165
+ imageio==2.31.5
166
+ imageio-ffmpeg==0.4.9
167
+ imagesize==1.4.1
168
+ importlib-metadata==6.8.0
169
+ importlib-resources==6.1.0
170
+ iniconfig==2.0.0
171
+ ipykernel==6.25.2
172
+ ipython==8.16.1
173
+ ipython-genutils==0.2.0
174
+ ipywidgets==8.1.1
175
+ isoduration==20.11.0
176
+ isort==5.12.0
177
+ itsdangerous==1.1.0
178
+ jedi==0.19.1
179
+ Jinja2==2.11.3
180
+ joblib==1.3.2
181
+ jsonpointer==2.4
182
+ jsonschema==4.19.2
183
+ jsonschema-specifications==2023.7.1
184
+ jupyter==1.0.0
185
+ jupyter-console==6.6.3
186
+ jupyter-events==0.9.0
187
+ jupyter_client==7.4.9
188
+ jupyter_core==5.3.2
189
+ jupyter_server==2.10.0
190
+ jupyter_server_terminals==0.4.4
191
+ jupyterlab-flake8==0.7.1
192
+ jupyterlab-pygments==0.2.2
193
+ jupyterlab-widgets==3.0.9
194
+ keras==2.14.0
195
+ keras-rl==0.4.2
196
+ kiwisolver==1.4.5
197
+ lark-parser==0.12.0
198
+ lazy-object-proxy==1.9.0
199
+ libclang==16.0.6
200
+ libcst==0.4.10
201
+ llvmlite==0.41.1
202
+ Markdown==3.5
203
+ markdown-it-py==3.0.0
204
+ MarkupSafe==2.0.1
205
+ matplotlib==3.8.0
206
+ matplotlib-inline==0.1.6
207
+ mccabe==0.7.0
208
+ mdurl==0.1.2
209
+ mediapy==1.1.9
210
+ mistune==0.8.4
211
+ ml-dtypes==0.2.0
212
+ moviepy==1.0.3
213
+ mpire==2.8.0
214
+ mpmath==1.3.0
215
+ mr-proper==0.0.7
216
+ mujoco==2.3.7
217
+ multidict==6.0.4
218
+ mypy-extensions==1.0.0
219
+ natsort==8.4.0
220
+ nbclassic==1.0.0
221
+ nbclient==0.5.13
222
+ nbconvert==6.4.5
223
+ nbformat==5.9.2
224
+ nest-asyncio==1.5.8
225
+ networkx==3.1
226
+ notebook==6.5.6
227
+ notebook_shim==0.2.3
228
+ numba==0.58.1
229
+ numpy==1.26.0
230
+ oauthlib==3.2.2
231
+ opencv-python==4.8.1.78
232
+ opt-einsum==3.3.0
233
+ orjson==3.9.10
234
+ overcooked-ai==1.1.0
235
+ overrides==7.4.0
236
+ packaging==23.2
237
+ pandas==2.1.1
238
+ pandas-vet==0.2.3
239
+ pandocfilters==1.5.0
240
+ parso==0.8.3
241
+ pathspec==0.11.2
242
+ pathtools==0.1.2
243
+ pbr==5.11.1
244
+ pep8-naming==0.13.3
245
+ pettingzoo==1.24.1
246
+ pexpect==4.8.0
247
+ pickleshare==0.7.5
248
+ Pillow==10.0.1
249
+ platformdirs==3.11.0
250
+ pluggy==1.3.0
251
+ proglog==0.1.10
252
+ prometheus-client==0.18.0
253
+ prompt-toolkit==3.0.39
254
+ protobuf==4.24.4
255
+ psutil==5.9.5
256
+ ptyprocess==0.7.0
257
+ pure-eval==0.2.2
258
+ pyasn1==0.5.0
259
+ pyasn1-modules==0.3.0
260
+ pybetter==0.4.1
261
+ pycln==2.3.0
262
+ pycodestyle==2.9.1
263
+ pycparser==2.21
264
+ pydantic==2.4.2
265
+ pydantic_core==2.10.1
266
+ pydocstyle==6.3.0
267
+ pydub==0.25.1
268
+ pyemojify==0.2.0
269
+ pyflakes==2.5.0
270
+ pygame==2.3.0
271
+ pyglet==2.0.0
272
+ Pygments==2.16.1
273
+ pylint==2.17.7
274
+ pynng==0.7.2
275
+ PyOpenGL==3.1.7
276
+ pyparsing==3.1.1
277
+ pyproject-api==1.6.1
278
+ pytest==7.4.3
279
+ pytest-cov==4.1.0
280
+ pytest-sugar==0.9.7
281
+ python-dateutil==2.8.2
282
+ python-dev-tools==2023.3.24
283
+ python-dotenv==1.0.0
284
+ python-json-logger==2.0.7
285
+ python-multipart==0.0.6
286
+ pytimeparse==1.1.8
287
+ pytz==2023.3.post1
288
+ pyupgrade==3.15.0
289
+ PyVirtualDisplay==3.0
290
+ PyYAML==6.0.1
291
+ pyzmq==24.0.1
292
+ qtconsole==5.5.0
293
+ QtPy==2.4.1
294
+ redis==5.0.1
295
+ referencing==0.30.2
296
+ removestar==1.5
297
+ requests==2.31.0
298
+ requests-oauthlib==1.3.1
299
+ responses==0.12.1
300
+ restructuredtext-lint==1.4.0
301
+ rfc3339-validator==0.1.4
302
+ rfc3986-validator==0.1.1
303
+ rich==13.6.0
304
+ rlcard==1.0.5
305
+ rpds-py==0.12.0
306
+ rsa==4.9
307
+ sb3-contrib==2.1.0
308
+ scikit-learn==1.3.1
309
+ scipy==1.11.3
310
+ seaborn==0.13.0
311
+ semantic-version==2.10.0
312
+ Send2Trash==1.8.2
313
+ sentry-sdk==1.32.0
314
+ setproctitle==1.3.3
315
+ shellingham==1.5.4
316
+ Shimmy==1.3.0
317
+ six==1.16.0
318
+ smmap==5.0.1
319
+ sniffio==1.3.0
320
+ snowballstemmer==2.2.0
321
+ sortedcontainers==2.4.0
322
+ soupsieve==2.5
323
+ Sphinx==6.2.1
324
+ sphinxcontrib-applehelp==1.0.7
325
+ sphinxcontrib-devhelp==1.0.5
326
+ sphinxcontrib-htmlhelp==2.0.4
327
+ sphinxcontrib-jsmath==1.0.1
328
+ sphinxcontrib-qthelp==1.0.6
329
+ sphinxcontrib-serializinghtml==1.1.9
330
+ ssort==0.11.6
331
+ stable-baselines3==2.1.0
332
+ stack-data==0.6.3
333
+ starlette==0.27.0
334
+ stdlib-list==0.9.0
335
+ stevedore==5.1.0
336
+ swig==4.1.1
337
+ sympy==1.12
338
+ tabulate==0.9.0
339
+ tensorboard==2.14.1
340
+ tensorboard-data-server==0.7.1
341
+ tensorboardX==2.6.2.2
342
+ tensordict==0.2.0
343
+ tensordict-nightly==2023.10.6
344
+ tensorflow==2.14.0
345
+ tensorflow-estimator==2.14.0
346
+ tensorflow-io-gcs-filesystem==0.34.0
347
+ tensorflow-macos==2.14.0
348
+ tensorflow-metal==1.1.0
349
+ termcolor==2.3.0
350
+ terminado==0.17.1
351
+ testpath==0.6.0
352
+ threadpoolctl==3.2.0
353
+ tinycss2==1.2.1
354
+ tokenize-rt==5.2.0
355
+ tomlkit==0.12.0
356
+ toolz==0.12.0
357
+ torch==2.1.0
358
+ torchrl @ git+https://github.com/pytorch/rl.git@bf264e0e24971fc05ec42b571de7b8df84043a51
359
+ torchsnapshot==0.1.0
360
+ torchvision==0.16.0
361
+ tornado==6.3.3
362
+ tox==4.11.3
363
+ tox-travis==0.12
364
+ tqdm==4.66.1
365
+ traitlets==5.11.2
366
+ treevalue==1.4.12
367
+ trueskill==0.4.5
368
+ typer==0.9.0
369
+ types-python-dateutil==2.8.19.14
370
+ typing-inspect==0.9.0
371
+ typing_extensions==4.8.0
372
+ tzdata==2023.3
373
+ Unidecode==1.3.7
374
+ untokenize==0.1.1
375
+ uri-template==1.3.0
376
+ urllib3==2.0.6
377
+ URLObject==2.4.3
378
+ uvicorn==0.24.0.post1
379
+ virtualenv==20.24.5
380
+ wandb==0.15.12
381
+ wcwidth==0.2.8
382
+ webcolors==1.13
383
+ webencodings==0.5.1
384
+ websocket-client==1.6.4
385
+ websockets==11.0.3
386
+ Werkzeug==1.0.1
387
+ widgetsnbextension==4.0.9
388
+ wrapt==1.14.1
389
+ yapf==0.29.0
390
+ yarl==1.9.2
391
+ yattag==1.15.1
392
+ zipp==3.17.0
tmp/ddpg/actor_ddpg CHANGED
Binary files a/tmp/ddpg/actor_ddpg and b/tmp/ddpg/actor_ddpg differ
 
tmp/ddpg/critic_ddpg CHANGED
Binary files a/tmp/ddpg/critic_ddpg and b/tmp/ddpg/critic_ddpg differ
 
tmp/ddpg/target_actor_ddpg CHANGED
Binary files a/tmp/ddpg/target_actor_ddpg and b/tmp/ddpg/target_actor_ddpg differ
 
tmp/ddpg/target_critic_ddpg CHANGED
Binary files a/tmp/ddpg/target_critic_ddpg and b/tmp/ddpg/target_critic_ddpg differ
 
train.py CHANGED
@@ -4,24 +4,23 @@ import numpy as np
4
  import matplotlib.pyplot as plt
5
  import torch
6
  from captum.attr import (IntegratedGradients)
 
7
 
8
 
9
  class TrainingLoop:
10
  def __init__(self, env_spec, output_path='./output/', seed=0, **kwargs):
11
  assert env_spec in gym.envs.registry.keys()
12
 
13
- defaults = {
 
14
  "continuous": True,
15
  "gravity": -10.0,
16
  "render_mode": None
17
  }
18
 
19
- defaults.update(**kwargs)
20
 
21
- self.env = gym.make(
22
- env_spec,
23
- **defaults
24
- )
25
 
26
  torch.manual_seed(seed)
27
 
@@ -35,7 +34,13 @@ class TrainingLoop:
35
  def train(self):
36
  assert self.agent is not None
37
 
38
- self.agent.load_models()
 
 
 
 
 
 
39
 
40
  score_history = []
41
 
@@ -63,6 +68,12 @@ class TrainingLoop:
63
  def load_trained(self):
64
  assert self.agent is not None
65
 
 
 
 
 
 
 
66
  self.agent.load_models()
67
 
68
  score_history = []
@@ -84,12 +95,55 @@ class TrainingLoop:
84
 
85
  self.env.close()
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  # Model Explainability
88
 
89
  from captum.attr import (IntegratedGradients)
90
 
91
  def _collect_running_baseline_average(self, num_iterations: int) -> torch.Tensor:
92
  assert self.agent is not None
 
 
 
 
 
 
 
93
  print("--------- Collecting running baseline average ----------")
94
 
95
  self.agent.load_models()
@@ -129,6 +183,13 @@ class TrainingLoop:
129
 
130
  baseline = baseline_options[option]
131
 
 
 
 
 
 
 
 
132
  print("\n\n\n\n--------- Performing Attributions -----------")
133
 
134
  self.agent.load_models()
@@ -139,22 +200,32 @@ class TrainingLoop:
139
  self.agent.ig = ig
140
 
141
  score_history = []
 
142
 
143
  for i in range(50):
144
  done = False
145
  score = 0
146
  obs, _ = self.env.reset()
147
  while not done:
 
148
  act = self.agent.choose_action(observation=obs, baseline=baseline)
149
  new_state, reward, terminated, truncated, info = self.env.step(act)
150
  done = terminated or truncated
151
  score += reward
152
  obs = new_state
153
 
 
154
  score_history.append(score)
155
  print("episode", i, "score %.2f" % score, "100 game average %.2f" % np.mean(score_history[-100:]))
156
 
157
  self.env.close()
158
 
159
- return self.agent.attributions
 
 
 
 
 
 
 
160
 
 
4
  import matplotlib.pyplot as plt
5
  import torch
6
  from captum.attr import (IntegratedGradients)
7
+ from gymnasium.wrappers import RecordVideo
8
 
9
 
10
  class TrainingLoop:
11
  def __init__(self, env_spec, output_path='./output/', seed=0, **kwargs):
12
  assert env_spec in gym.envs.registry.keys()
13
 
14
+ self.defaults = {
15
+ "id": env_spec,
16
  "continuous": True,
17
  "gravity": -10.0,
18
  "render_mode": None
19
  }
20
 
21
+ self.env = None
22
 
23
+ self.defaults.update(**kwargs)
 
 
 
24
 
25
  torch.manual_seed(seed)
26
 
 
34
  def train(self):
35
  assert self.agent is not None
36
 
37
+ self.defaults["render_mode"] = None
38
+
39
+ self.env = gym.make(
40
+ **self.defaults
41
+ )
42
+
43
+ # self.agent.load_models()
44
 
45
  score_history = []
46
 
 
68
  def load_trained(self):
69
  assert self.agent is not None
70
 
71
+ self.defaults["render_mode"] = None
72
+
73
+ self.env = gym.make(
74
+ **self.defaults
75
+ )
76
+
77
  self.agent.load_models()
78
 
79
  score_history = []
 
95
 
96
  self.env.close()
97
 
98
+ # Video Recording
99
+
100
+ # def render_video(self, episode_trigger=100):
101
+ # assert self.agent is not None
102
+
103
+ # self.defaults["render_mode"] = "rgb_array"
104
+ # self.env = gym.make(
105
+ # **self.defaults
106
+ # )
107
+
108
+ # episode_trigger_callable = lambda x: x % episode_trigger == 0
109
+
110
+ # self.env = RecordVideo(env=self.env, video_folder=self.output_path, name_prefix=f"{self.defaults['id']}-recording", episode_trigger=episode_trigger_callable, disable_logger=True)
111
+
112
+ # self.agent.load_models()
113
+
114
+ # score_history = []
115
+
116
+ # for i in range(200):
117
+ # done = False
118
+ # score = 0
119
+ # obs, _ = self.env.reset()
120
+ # while not done:
121
+ # act = self.agent.choose_action(observation=obs)
122
+ # new_state, reward, terminated, truncated, info = self.env.step(act)
123
+ # done = terminated or truncated
124
+ # score += reward
125
+ # obs = new_state
126
+
127
+
128
+ # score_history.append(score)
129
+ # print("episode", i, "score %.2f" % score, "100 game average %.2f" % np.mean(score_history[-100:]))
130
+
131
+ # self.env.close()
132
+
133
+
134
  # Model Explainability
135
 
136
  from captum.attr import (IntegratedGradients)
137
 
138
  def _collect_running_baseline_average(self, num_iterations: int) -> torch.Tensor:
139
  assert self.agent is not None
140
+
141
+ self.defaults["render_mode"] = None
142
+
143
+ self.env = gym.make(
144
+ **self.defaults
145
+ )
146
+
147
  print("--------- Collecting running baseline average ----------")
148
 
149
  self.agent.load_models()
 
183
 
184
  baseline = baseline_options[option]
185
 
186
+ self.defaults["render_mode"] = "rgb_array"
187
+
188
+ self.env = gym.make(
189
+ **self.defaults
190
+ )
191
+
192
+
193
  print("\n\n\n\n--------- Performing Attributions -----------")
194
 
195
  self.agent.load_models()
 
200
  self.agent.ig = ig
201
 
202
  score_history = []
203
+ frames = []
204
 
205
  for i in range(50):
206
  done = False
207
  score = 0
208
  obs, _ = self.env.reset()
209
  while not done:
210
+ frames.append(self.env.render())
211
  act = self.agent.choose_action(observation=obs, baseline=baseline)
212
  new_state, reward, terminated, truncated, info = self.env.step(act)
213
  done = terminated or truncated
214
  score += reward
215
  obs = new_state
216
 
217
+
218
  score_history.append(score)
219
  print("episode", i, "score %.2f" % score, "100 game average %.2f" % np.mean(score_history[-100:]))
220
 
221
  self.env.close()
222
 
223
+ try:
224
+ assert len(frames) == len(self.agent.attributions)
225
+ except AssertionError:
226
+ print("Frames and agent attribution history are not the same shape!")
227
+ else:
228
+ pass
229
+
230
+ return (frames, self.agent.attributions)
231