Upload folder using huggingface_hub
- .gitattributes +1 -0
- .summary/0/events.out.tfevents.1718465455.koa03 +3 -0
- README.md +56 -0
- checkpoint_p0/best_000903328_1850015744_reward_37.590.pth +3 -0
- checkpoint_p0/checkpoint_000976608_2000093184.pth +3 -0
- checkpoint_p0/checkpoint_000976624_2000125952.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000025152_51511296.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000051136_104726528.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000077056_157810688.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000102912_210763776.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000128864_263913472.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000154736_316899328.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000180800_370278400.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000206272_422445056.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000232480_476119040.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000258176_528744448.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000284128_581894144.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000310224_635338752.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000335920_687964160.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000361712_740786176.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000387472_793542656.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000413072_845971456.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000438832_898727936.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000464800_951910400.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000490624_1004797952.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000516736_1058275328.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000542656_1111359488.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000568768_1164836864.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000594944_1218445312.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000620864_1271529472.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000646784_1324613632.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000672528_1377337344.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000698432_1430388736.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000724224_1483210752.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000750080_1536163840.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000775872_1588985856.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000801792_1642070016.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000827776_1695285248.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000853952_1748893696.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000879680_1801584640.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000912544_1868890112.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000961856_1969881088.pth +3 -0
- config.json +167 -0
- git.diff +712 -0
- replay.mp4 +3 -0
- sf_log.txt +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+replay.mp4 filter=lfs diff=lfs merge=lfs -text
.summary/0/events.out.tfevents.1718465455.koa03
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4cf539fd30b5412b5a0f264179fb9b4b56eb7155724db7142be3880eb7d1fd2a
size 19077526
README.md
ADDED
@@ -0,0 +1,56 @@
---
library_name: sample-factory
tags:
- deep-reinforcement-learning
- reinforcement-learning
- sample-factory
model-index:
- name: APPO
  results:
  - task:
      type: reinforcement-learning
      name: reinforcement-learning
    dataset:
      name: atari_carnival
      type: atari_carnival
    metrics:
    - type: mean_reward
      value: 718.00 +/- 546.29
      name: mean_reward
      verified: false
---

An **APPO** model trained on the **atari_carnival** environment.

This model was trained using Sample-Factory 2.0: https://github.com/alex-petrenko/sample-factory.
Documentation for how to use Sample-Factory can be found at https://www.samplefactory.dev/

## Downloading the model

After installing Sample-Factory, download the model with:
```
python -m sample_factory.huggingface.load_from_hub -r ksridhar/atari_2B_atari_carnival_1111
```
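
If you only want to inspect the weights rather than run rollouts, the `.pth` files in this repository are ordinary PyTorch checkpoints written by Sample-Factory. A minimal sketch, assuming only that PyTorch is installed and the repository has been downloaded locally (the exact keys inside the checkpoint dict are a Sample-Factory implementation detail):
```
import torch

# Load the best-reward checkpoint included in this upload onto the CPU.
ckpt = torch.load(
    "checkpoint_p0/best_000903328_1850015744_reward_37.590.pth",
    map_location="cpu",
)
print(type(ckpt))         # typically a dict holding model weights and training state
print(list(ckpt.keys()))  # e.g. network parameters, optimizer state, env step counters
```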

## Using the model

To run the model after download, use the `enjoy` script corresponding to this environment:
```
python -m <path.to.enjoy.module> --algo=APPO --env=atari_carnival --train_dir=./train_dir --experiment=atari_2B_atari_carnival_1111
```

You can also upload models to the Hugging Face Hub using the same script with the `--push_to_hub` flag.
See https://www.samplefactory.dev/10-huggingface/huggingface/ for more details.

## Training with this model

To continue training with this model, use the `train` script corresponding to this environment:
```
python -m <path.to.train.module> --algo=APPO --env=atari_carnival --train_dir=./train_dir --experiment=atari_2B_atari_carnival_1111 --restart_behavior=resume --train_for_env_steps=10000000000
```

Note: you may have to adjust `--train_for_env_steps` to a suitably high number, as the experiment will resume from the step count at which it previously stopped.
checkpoint_p0/best_000903328_1850015744_reward_37.590.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:583c645c02b1df4ae28d29773fce170c3acc68c8ca844712b213f3fb98f3444c
size 20722280
checkpoint_p0/checkpoint_000976608_2000093184.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c7ec598d1afb79c7ba319ad7700b5e0b3bfac4eb59507862243e7edbc3932a26
size 20722628
checkpoint_p0/checkpoint_000976624_2000125952.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ac2166f331ad435bd4e9b0a4e9b241bda41ed83eb4d3867bb19b9d6a59546eb2
size 20722628
checkpoint_p0/milestones/checkpoint_000025152_51511296.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:38bd7e1c5c45514f4ac4abba9776c5603c70439bc85372a0b7fa2c806fdca3c3
size 20723568
checkpoint_p0/milestones/checkpoint_000051136_104726528.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b3a0e06777147625cc609707d76f3a01562606c0e8b2bb2b09522c22ba66cdee
size 20723626
checkpoint_p0/milestones/checkpoint_000077056_157810688.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:72dc3cd28c88a5b9351900082ae0b6b7cbb8b6bc159f0b4fb29cc652b36cf152
size 20723626
checkpoint_p0/milestones/checkpoint_000102912_210763776.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ab56d3ca1d032baf35b406178707b52a1589df04fcd84520aafe48657fc21c77
size 20723626
checkpoint_p0/milestones/checkpoint_000128864_263913472.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:91eca10d3c06ca75ed1644b3d3428a964586718e6a7ecdb02a2db364b247313b
size 20723626
checkpoint_p0/milestones/checkpoint_000154736_316899328.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:95339062f638e1dd6a57e48da59bdc2c38ed62ba51cd39ebc4245ab9a0a38302
size 20723626
checkpoint_p0/milestones/checkpoint_000180800_370278400.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:af5eccbd1360bb3a6cbb0635ad8b65d005698b6753c733181fea7c63f589875a
size 20723626
checkpoint_p0/milestones/checkpoint_000206272_422445056.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1a4a101a046102a298aec44ff87dc1928c8daf05bc2f6e675ccf7e0a5e6f4605
size 20723626
checkpoint_p0/milestones/checkpoint_000232480_476119040.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:043cd38aaa11b0387f5bf885d1efcfdebd44810ef4bcacac9a05b3737974e368
size 20723626
checkpoint_p0/milestones/checkpoint_000258176_528744448.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:14a476c3fcabc8c5304cec02131ac1e4da542c73b9452c85fae69ac9d6503c2e
size 20723626
checkpoint_p0/milestones/checkpoint_000284128_581894144.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0d22ebed92a65bb3c1d6009b1b90bdc76cb28dcede753f725a20ab9673d6447c
size 20723626
checkpoint_p0/milestones/checkpoint_000310224_635338752.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b8ed8c981843651f96355f155c54023d3e60382c3c00f11300543d09cc9c5ae4
size 20723626
checkpoint_p0/milestones/checkpoint_000335920_687964160.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c728fc54979af87430f0b0fe8a49bd0d4165a36dae493b568af65ac6345546a1
size 20723626
checkpoint_p0/milestones/checkpoint_000361712_740786176.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c8e7949a5dccde8bfd4117b95f16f898a059313180746553734f975f7141b461
size 20723626
checkpoint_p0/milestones/checkpoint_000387472_793542656.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:10424e0044a91b6978f7136bc3aabe6694ca6397e169b04bd1d9d0a31dd6810b
size 20723626
checkpoint_p0/milestones/checkpoint_000413072_845971456.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8c8fde12d7fec9471143f99b08aff7dd3b1bb8b94dc39c46528fe5e5fca8149f
size 20723626
checkpoint_p0/milestones/checkpoint_000438832_898727936.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:03ce10d9dc3205f2a8637f66510bc60112811e470db0c7ffaea60f4493f0fb23
size 20723626
checkpoint_p0/milestones/checkpoint_000464800_951910400.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b992be43ed2216501381c7f1771d398594b07aca2ee4c1ac4a7c1268f2b1f8ff
size 20723626
checkpoint_p0/milestones/checkpoint_000490624_1004797952.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:22e2b3c26570a835b7250e9f9073b3ea05af820efebdc29d70801f508e2c66ff
size 20723684
checkpoint_p0/milestones/checkpoint_000516736_1058275328.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d4002c0ea6fa8292e4eddf06940877488e9979cde6f7b033d67c996696e79508
size 20723684
checkpoint_p0/milestones/checkpoint_000542656_1111359488.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3feda5847113a4d1ebe80c07b872eab92ffb3d7159ddc715b5329e6fffcfa3ef
size 20723684
checkpoint_p0/milestones/checkpoint_000568768_1164836864.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bef90e1168a3362b876e6c094f61793b550d3b43e16aedfda3fdda0f3b0c39cc
size 20723684
checkpoint_p0/milestones/checkpoint_000594944_1218445312.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e013d933cf86f6d867c28406c84838abd404ac1bbd52cc793549bd30aa947299
size 20723684
checkpoint_p0/milestones/checkpoint_000620864_1271529472.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9ee5a3035f2c32de7c4b516029aab38df5691ce6c6b4e77942b7941c4873b2c0
size 20723684
checkpoint_p0/milestones/checkpoint_000646784_1324613632.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1f770eee4a48913ea77863874f0f88e003ada4cf921c7249534d04173084ebf8
size 20723684
checkpoint_p0/milestones/checkpoint_000672528_1377337344.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4fd959a59678ea98412b1bc346306fa2ca42e09b79ff87905c1e6431784d425d
size 20723684
checkpoint_p0/milestones/checkpoint_000698432_1430388736.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5e0dcbf805a451848487f991a8b2424d5d8b25eef038db798cbd40960e6cd7e7
size 20723684
checkpoint_p0/milestones/checkpoint_000724224_1483210752.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:928747fe2c439d68c3394654c95aa39f49eb6dce8c965c2bfc1c59b8d3034318
size 20723684
checkpoint_p0/milestones/checkpoint_000750080_1536163840.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9ba8292cb5946cb20fd3b718c2611324a8dc4e15794547158652c6e48b9a2673
size 20723684
checkpoint_p0/milestones/checkpoint_000775872_1588985856.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3ec3267e3bfaad4e2d87e2c133587f22a7b44b614189c89d658ca636c6efa501
size 20723684
checkpoint_p0/milestones/checkpoint_000801792_1642070016.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c378961005d1914076bd32afa1f8b5521336afa9f986fa99255ae53e3fd2c0fe
size 20723684
checkpoint_p0/milestones/checkpoint_000827776_1695285248.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f83480c47c7d43709335c568acfe32dd9b24618ed60cc3896d5b40735abaa06f
size 20723684
checkpoint_p0/milestones/checkpoint_000853952_1748893696.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1b176b68602ad39521c011a78be249616800f591673cddb98467057385ad7912
size 20723684
checkpoint_p0/milestones/checkpoint_000879680_1801584640.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:40e7eb08c65a0d0665a286fb89d1f3c0bb94f3de6cbf0ad59b81ce70cb85bb15
size 20723684
checkpoint_p0/milestones/checkpoint_000912544_1868890112.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:71dc0efa7a4baa3f89466c4ddf49aa9dabe101901eecc5f08704caf5332c3d9b
size 20723684
checkpoint_p0/milestones/checkpoint_000961856_1969881088.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:135bcc4724c00196d2e8e784c615a74ee5a15b72a4044003dbda909b52e8a281
size 20723684
config.json
ADDED
@@ -0,0 +1,167 @@
{
    "help": false,
    "algo": "APPO",
    "env": "atari_carnival",
    "experiment": "atari_2B_atari_carnival_1111",
    "train_dir": "train_dir",
    "restart_behavior": "resume",
    "device": "gpu",
    "seed": 1111,
    "num_policies": 1,
    "async_rl": true,
    "serial_mode": false,
    "batched_sampling": true,
    "num_batches_to_accumulate": 2,
    "worker_num_splits": 1,
    "policy_workers_per_policy": 1,
    "max_policy_lag": 1000,
    "num_workers": 4,
    "num_envs_per_worker": 1,
    "batch_size": 1024,
    "num_batches_per_epoch": 8,
    "num_epochs": 2,
    "rollout": 64,
    "recurrence": 1,
    "shuffle_minibatches": false,
    "gamma": 0.99,
    "reward_scale": 1.0,
    "reward_clip": 1000.0,
    "value_bootstrap": false,
    "normalize_returns": true,
    "exploration_loss_coeff": 0.0004677351413,
    "value_loss_coeff": 0.5,
    "kl_loss_coeff": 0.0,
    "exploration_loss": "entropy",
    "gae_lambda": 0.95,
    "ppo_clip_ratio": 0.1,
    "ppo_clip_value": 1.0,
    "with_vtrace": false,
    "vtrace_rho": 1.0,
    "vtrace_c": 1.0,
    "optimizer": "adam",
    "adam_eps": 1e-05,
    "adam_beta1": 0.9,
    "adam_beta2": 0.999,
    "max_grad_norm": 0.0,
    "learning_rate": 0.0003033891184,
    "lr_schedule": "linear_decay",
    "lr_schedule_kl_threshold": 0.008,
    "lr_adaptive_min": 1e-06,
    "lr_adaptive_max": 0.01,
    "obs_subtract_mean": 0.0,
    "obs_scale": 255.0,
    "normalize_input": true,
    "normalize_input_keys": [
        "obs"
    ],
    "decorrelate_experience_max_seconds": 1,
    "decorrelate_envs_on_one_worker": true,
    "actor_worker_gpus": [],
    "set_workers_cpu_affinity": true,
    "force_envs_single_thread": false,
    "default_niceness": 0,
    "log_to_file": true,
    "experiment_summaries_interval": 3,
    "flush_summaries_interval": 30,
    "stats_avg": 100,
    "summaries_use_frameskip": true,
    "heartbeat_interval": 20,
    "heartbeat_reporting_interval": 180,
    "train_for_env_steps": 2000000000,
    "train_for_seconds": 3600000,
    "save_every_sec": 120,
    "keep_checkpoints": 2,
    "load_checkpoint_kind": "latest",
    "save_milestones_sec": 1200,
    "save_best_every_sec": 5,
    "save_best_metric": "reward",
    "save_best_after": 100000,
    "benchmark": false,
    "encoder_mlp_layers": [
        512,
        512
    ],
    "encoder_conv_architecture": "convnet_atari",
    "encoder_conv_mlp_layers": [
        512
    ],
    "use_rnn": false,
    "rnn_size": 512,
    "rnn_type": "gru",
    "rnn_num_layers": 1,
    "decoder_mlp_layers": [],
    "nonlinearity": "relu",
    "policy_initialization": "orthogonal",
    "policy_init_gain": 1.0,
    "actor_critic_share_weights": true,
    "adaptive_stddev": false,
    "continuous_tanh_scale": 0.0,
    "initial_stddev": 1.0,
    "use_env_info_cache": false,
    "env_gpu_actions": false,
    "env_gpu_observations": true,
    "env_frameskip": 4,
    "env_framestack": 4,
    "pixel_format": "CHW",
    "use_record_episode_statistics": true,
    "episode_counter": false,
    "with_wandb": false,
    "wandb_user": null,
    "wandb_project": "sample_factory",
    "wandb_group": null,
    "wandb_job_type": "SF",
    "wandb_tags": [],
    "with_pbt": false,
    "pbt_mix_policies_in_one_env": true,
    "pbt_period_env_steps": 5000000,
    "pbt_start_mutation": 20000000,
    "pbt_replace_fraction": 0.3,
    "pbt_mutation_rate": 0.15,
    "pbt_replace_reward_gap": 0.1,
    "pbt_replace_reward_gap_absolute": 1e-06,
    "pbt_optimize_gamma": false,
    "pbt_target_objective": "true_objective",
    "pbt_perturb_min": 1.1,
    "pbt_perturb_max": 1.5,
    "env_agents": 512,
    "command_line": "--seed=1111 --experiment=atari_2B_atari_carnival_1111 --env=atari_carnival --train_for_seconds=3600000 --algo=APPO --gamma=0.99 --num_workers=4 --num_envs_per_worker=1 --worker_num_splits=1 --env_agents=512 --benchmark=False --max_grad_norm=0.0 --decorrelate_experience_max_seconds=1 --encoder_conv_architecture=convnet_atari --encoder_conv_mlp_layers 512 --nonlinearity=relu --num_policies=1 --normalize_input=True --normalize_input_keys obs --normalize_returns=True --async_rl=True --batched_sampling=True --train_for_env_steps=2000000000 --save_milestones_sec=1200 --train_dir train_dir --rollout 64 --exploration_loss_coeff 0.0004677351413 --num_epochs 2 --batch_size 1024 --num_batches_per_epoch 8 --learning_rate 0.0003033891184",
    "cli_args": {
        "algo": "APPO",
        "env": "atari_carnival",
        "experiment": "atari_2B_atari_carnival_1111",
        "train_dir": "train_dir",
        "seed": 1111,
        "num_policies": 1,
        "async_rl": true,
        "batched_sampling": true,
        "worker_num_splits": 1,
        "num_workers": 4,
        "num_envs_per_worker": 1,
        "batch_size": 1024,
        "num_batches_per_epoch": 8,
        "num_epochs": 2,
        "rollout": 64,
        "gamma": 0.99,
        "normalize_returns": true,
        "exploration_loss_coeff": 0.0004677351413,
        "max_grad_norm": 0.0,
        "learning_rate": 0.0003033891184,
        "normalize_input": true,
        "normalize_input_keys": [
            "obs"
        ],
        "decorrelate_experience_max_seconds": 1,
        "train_for_env_steps": 2000000000,
        "train_for_seconds": 3600000,
        "save_milestones_sec": 1200,
        "benchmark": false,
        "encoder_conv_architecture": "convnet_atari",
        "encoder_conv_mlp_layers": [
            512
        ],
        "nonlinearity": "relu",
        "env_agents": 512
    },
    "git_hash": "e259c57b8c7aa9c7f541e9efd1316f8e6f97a6db",
    "git_repo_name": "https://github.com/kaustubhsridhar/jat_regent.git"
}
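For quick reference, the batching fields in the config above combine as follows. A small sketch of the implied arithmetic (values copied from config.json; the derived numbers are not stored in the file, and the env-steps-per-rollout figure assumes each of the 4 workers runs all 512 envpool agents):

```python
# Sketch: quantities implied by the APPO config above (assumptions noted inline).
batch_size = 1024              # minibatch size per SGD step
num_batches_per_epoch = 8
num_epochs = 2
rollout = 64
env_agents = 512               # envpool agents per worker (assumption)
num_workers = 4

samples_per_iteration = batch_size * num_batches_per_epoch      # 8192 transitions per training iteration
sgd_steps_per_iteration = num_batches_per_epoch * num_epochs    # 16 minibatch updates over those transitions
env_steps_per_rollout = rollout * env_agents * num_workers      # 131072 env steps per full rollout (under the assumption above)
print(samples_per_iteration, sgd_steps_per_iteration, env_steps_per_rollout)
```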
git.diff
ADDED
@@ -0,0 +1,712 @@
diff --git a/README.md b/README.md
index e51a12b..a6e1ca1 100644
--- a/README.md
+++ b/README.md
@@ -21,6 +21,21 @@ conda activate jat
 pip install -e .[dev]
 ```
 
+## REGENT fork of sample-factory: Installation
+Following [this install ink](https://www.samplefactory.dev/01-get-started/installation/) but for the fork:
+```shell
+git clone https://github.com/kaustubhsridhar/sample-factory.git
+cd sample-factory
+pip install -e .[dev,mujoco,atari,envpool,vizdoom]
+```
+
+# Regent fork of sample-factory: Train Unseen Env Policies and Generate Datasets
+Train policies using envpool's atari:
+```shell
+bash scripts_sample-factory/train_unseen_atari.sh
+```
+Note that the training command inside the above script was obtained from the config files of Ed Beeching's Atari 57 models on Huggingface. An example is [here](https://huggingface.co/edbeeching/atari_2B_atari_mspacman_1111/blob/main/cfg.json#L124). See my discussion [here](https://huggingface.co/edbeeching/atari_2B_atari_mspacman_1111/discussions/2).
+
 ## PREV Installation
 
 To get started with JAT, follow these steps:
@@ -155,12 +170,21 @@ python -u scripts_jat_regent/eval_RandP.py --task ${TASK} &> outputs/RandP/${TAS
 ```
 
 ### REGENT Analyze data
+Necessary:
 ```shell
-python -u examples_regent/compare_datasets.py &> examples_regent/compare_datasets.txt &
-
 python -u examples_regent/analyze_rows_tokenized.py &> examples_regent/analyze_rows_tokenized.txt &
+```
 
+Already ran and output dict in code:
+```shell
 python -u examples_regent/get_dim_all_vector_tasks.py &> examples_regent/get_dim_all_vector_tasks.txt &
+
+python -u examples_regent/count_rows_to_consider.py &> examples_regent/count_rows_to_consider.txt &
+```
+
+Optional:
+```shell
+python -u examples_regent/compare_datasets.py &> examples_regent/compare_datasets.txt &
 ```
 
 ## PREV Dataset
diff --git a/jat_regent/RandP.py b/jat_regent/RandP.py
deleted file mode 100644
index b2bd8bf..0000000
--- a/jat_regent/RandP.py
+++ /dev/null
@@ -1,38 +0,0 @@
-import warnings
-from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
-
-import numpy as np
-import torch
-import torch.nn.functional as F
-from gymnasium import spaces
-from torch import BoolTensor, FloatTensor, LongTensor, Tensor, nn
-from transformers import GPTNeoModel, GPTNeoPreTrainedModel
-from transformers.modeling_outputs import ModelOutput
-from transformers.models.vit.modeling_vit import ViTPatchEmbeddings
-
-from jat.configuration_jat import JatConfig
-from jat.processing_jat import JatProcessor
-
-
-class RandP():
-    def __init__(self, dataset) -> None:
-        self.steps = 0
-        # create an index for retrieval in vector obs envs (OR) collect all images in Atari
-
-    def reset_rl(self):
-        self.steps = 0
-
-    def get_next_action(
-        self,
-        processor: JatProcessor,
-        continuous_observation: Optional[List[float]] = None,
-        discrete_observation: Optional[List[int]] = None,
-        text_observation: Optional[str] = None,
-        image_observation: Optional[np.ndarray] = None,
-        action_space: Union[spaces.Box, spaces.Discrete] = None,
-        reward: Optional[float] = None,
-        deterministic: bool = False,
-        context_window: Optional[int] = None,
-    ):
-        pass
 
diff --git a/jat_regent/modelling_jat_regent.py b/jat_regent/modelling_jat_regent.py
deleted file mode 100644
index e69de29..0000000
diff --git a/jat_regent/utils.py b/jat_regent/utils.py
index 56bfb44..36f6cca 100644
--- a/jat_regent/utils.py
+++ b/jat_regent/utils.py
@@ -8,23 +8,35 @@ from tqdm import tqdm
 from autofaiss import build_index
 
 
+UNSEEN_TASK_NAMES = { # Total -- atari: 57, metaworld: 50, babyai: 39, mujoco: 11
+
+}
+
 def myprint(str):
-    # check if first character of string is a newline character
-    if str[0] == '\n':
-        str_without_newline = str[1:]
+    # check if first characters of string are newline character
+    num_newlines = 0
+    while str[num_newlines] == '\n':
         print()
-    else:
-        str_without_newline = str
+        num_newlines += 1
+    str_without_newline = str[num_newlines:]
     print(f'{datetime.now().strftime("%Y-%m-%d %H:%M:%S")}: {str_without_newline}')
 
 def is_png_img(item):
     return isinstance(item, PngImagePlugin.PngImageFile)
 
+def get_last_row_for_1M_states(task):
+ last_row_idx = {'atari-alien': 14134, 'atari-amidar': 14319, 'atari-assault': 14427, 'atari-asterix': 14456, 'atari-asteroids': 14348, 'atari-atlantis': 14325, 'atari-bankheist': 14167, 'atari-battlezone': 13981, 'atari-beamrider': 13442, 'atari-berzerk': 13534, 'atari-bowling': 14110, 'atari-boxing': 14542, 'atari-breakout': 13474, 'atari-centipede': 14196, 'atari-choppercommand': 13397, 'atari-crazyclimber': 14026, 'atari-defender': 13504, 'atari-demonattack': 13499, 'atari-doubledunk': 14292, 'atari-enduro': 13260, 'atari-fishingderby': 14073, 'atari-freeway': 14016, 'atari-frostbite': 14075, 'atari-gopher': 13143, 'atari-gravitar': 14405, 'atari-hero': 14044, 'atari-icehockey': 14017, 'atari-jamesbond': 12678, 'atari-kangaroo': 14248, 'atari-krull': 14204, 'atari-kungfumaster': 14030, 'atari-montezumarevenge': 14219, 'atari-mspacman': 14120, 'atari-namethisgame': 13575, 'atari-phoenix': 13539, 'atari-pitfall': 14287, 'atari-pong': 14151, 'atari-privateeye': 14105, 'atari-qbert': 14026, 'atari-riverraid': 14275, 'atari-roadrunner': 14127, 'atari-robotank': 14079, 'atari-seaquest': 14097, 'atari-skiing': 14708, 'atari-solaris': 14199, 'atari-spaceinvaders': 12652, 'atari-stargunner': 13822, 'atari-surround': 13840, 'atari-tennis': 14062, 'atari-timepilot': 13896, 'atari-tutankham': 13121, 'atari-upndown': 13504, 'atari-venture': 14260, 'atari-videopinball': 14272, 'atari-wizardofwor': 13920, 'atari-yarsrevenge': 13981, 'atari-zaxxon': 13833, 'babyai-action-obj-door': 95000, 'babyai-blocked-unlock-pickup': 29279, 'babyai-boss-level-no-unlock': 12087, 'babyai-boss-level': 12101, 'babyai-find-obj-s5': 32974, 'babyai-go-to-door': 95000, 'babyai-go-to-imp-unlock': 9286, 'babyai-go-to-local': 95000, 'babyai-go-to-obj-door': 95000, 'babyai-go-to-obj': 95000, 'babyai-go-to-red-ball-grey': 95000, 'babyai-go-to-red-ball-no-dists': 95000, 'babyai-go-to-red-ball': 95000, 'babyai-go-to-red-blue-ball': 95000, 'babyai-go-to-seq': 13744, 'babyai-go-to': 18974, 'babyai-key-corridor': 9014, 'babyai-mini-boss-level': 38119, 'babyai-move-two-across-s8n9': 24505, 'babyai-one-room-s8': 95000, 'babyai-open-door': 95000, 'babyai-open-doors-order-n4': 95000, 'babyai-open-red-door': 95000, 'babyai-open-two-doors': 73291, 'babyai-open': 32559, 'babyai-pickup-above': 34084, 'babyai-pickup-dist': 89640, 'babyai-pickup-loc': 95000, 'babyai-pickup': 18670, 'babyai-put-next-local': 83187, 'babyai-put-next': 56986, 'babyai-synth-loc': 21605, 'babyai-synth-seq': 13049, 'babyai-synth': 19409, 'babyai-unblock-pickup': 17881, 'babyai-unlock-local': 71186, 'babyai-unlock-pickup': 50883, 'babyai-unlock-to-unlock': 23062, 'babyai-unlock': 11734, 'metaworld-assembly': 10000, 'metaworld-basketball': 10000, 'metaworld-bin-picking': 10000, 'metaworld-box-close': 10000, 'metaworld-button-press-topdown-wall': 10000, 'metaworld-button-press-topdown': 10000, 'metaworld-button-press-wall': 10000, 'metaworld-button-press': 10000, 'metaworld-coffee-button': 10000, 'metaworld-coffee-pull': 10000, 'metaworld-coffee-push': 10000, 'metaworld-dial-turn': 10000, 'metaworld-disassemble': 10000, 'metaworld-door-close': 10000, 'metaworld-door-lock': 10000, 'metaworld-door-open': 10000, 'metaworld-door-unlock': 10000, 'metaworld-drawer-close': 10000, 'metaworld-drawer-open': 10000, 'metaworld-faucet-close': 10000, 'metaworld-faucet-open': 10000, 'metaworld-hammer': 10000, 'metaworld-hand-insert': 10000, 'metaworld-handle-press-side': 10000, 'metaworld-handle-press': 10000, 'metaworld-handle-pull-side': 10000, 'metaworld-handle-pull': 10000, 
'metaworld-lever-pull': 10000, 'metaworld-peg-insert-side': 10000, 'metaworld-peg-unplug-side': 10000, 'metaworld-pick-out-of-hole': 10000, 'metaworld-pick-place-wall': 10000, 'metaworld-pick-place': 10000, 'metaworld-plate-slide-back-side': 10000, 'metaworld-plate-slide-back': 10000, 'metaworld-plate-slide-side': 10000, 'metaworld-plate-slide': 10000, 'metaworld-push-back': 10000, 'metaworld-push-wall': 10000, 'metaworld-push': 10000, 'metaworld-reach-wall': 10000, 'metaworld-reach': 10000, 'metaworld-shelf-place': 10000, 'metaworld-soccer': 10000, 'metaworld-stick-pull': 10000, 'metaworld-stick-push': 10000, 'metaworld-sweep-into': 10000, 'metaworld-sweep': 10000, 'metaworld-window-close': 10000, 'metaworld-window-open': 10000, 'mujoco-ant': 4023, 'mujoco-doublependulum': 4002, 'mujoco-halfcheetah': 4000, 'mujoco-hopper': 4931, 'mujoco-humanoid': 4119, 'mujoco-pendulum': 4959, 'mujoco-pusher': 9000, 'mujoco-reacher': 9000, 'mujoco-standup': 4000, 'mujoco-swimmer': 4000, 'mujoco-walker': 4101}
+    return last_row_idx[task]
+
+def get_last_row_for_100k_states(task):
+ last_row_idx = {'atari-alien': 3135, 'atari-amidar': 3142, 'atari-assault': 3132, 'atari-asterix': 3181, 'atari-asteroids': 3127, 'atari-atlantis': 3128, 'atari-bankheist': 3156, 'atari-battlezone': 3136, 'atari-beamrider': 3131, 'atari-berzerk': 3127, 'atari-bowling': 3148, 'atari-boxing': 3227, 'atari-breakout': 3128, 'atari-centipede': 3176, 'atari-choppercommand': 3144, 'atari-crazyclimber': 3134, 'atari-defender': 3127, 'atari-demonattack': 3127, 'atari-doubledunk': 3175, 'atari-enduro': 3126, 'atari-fishingderby': 3155, 'atari-freeway': 3131, 'atari-frostbite': 3146, 'atari-gopher': 3128, 'atari-gravitar': 3202, 'atari-hero': 3144, 'atari-icehockey': 3138, 'atari-jamesbond': 3131, 'atari-kangaroo': 3160, 'atari-krull': 3162, 'atari-kungfumaster': 3143, 'atari-montezumarevenge': 3168, 'atari-mspacman': 3143, 'atari-namethisgame': 3131, 'atari-phoenix': 3127, 'atari-pitfall': 3131, 'atari-pong': 3160, 'atari-privateeye': 3158, 'atari-qbert': 3136, 'atari-riverraid': 3157, 'atari-roadrunner': 3150, 'atari-robotank': 3133, 'atari-seaquest': 3138, 'atari-skiing': 3271, 'atari-solaris': 3129, 'atari-spaceinvaders': 3128, 'atari-stargunner': 3129, 'atari-surround': 3143, 'atari-tennis': 3129, 'atari-timepilot': 3132, 'atari-tutankham': 3127, 'atari-upndown': 3127, 'atari-venture': 3148, 'atari-videopinball': 3130, 'atari-wizardofwor': 3138, 'atari-yarsrevenge': 3129, 'atari-zaxxon': 3133, 'babyai-action-obj-door': 15923, 'babyai-blocked-unlock-pickup': 2919, 'babyai-boss-level-no-unlock': 1217, 'babyai-boss-level': 1159, 'babyai-find-obj-s5': 3345, 'babyai-go-to-door': 18875, 'babyai-go-to-imp-unlock': 923, 'babyai-go-to-local': 18724, 'babyai-go-to-obj-door': 16472, 'babyai-go-to-obj': 20197, 'babyai-go-to-red-ball-grey': 16953, 'babyai-go-to-red-ball-no-dists': 20165, 'babyai-go-to-red-ball': 18730, 'babyai-go-to-red-blue-ball': 16934, 'babyai-go-to-seq': 1439, 'babyai-go-to': 1964, 'babyai-key-corridor': 900, 'babyai-mini-boss-level': 3789, 'babyai-move-two-across-s8n9': 2462, 'babyai-one-room-s8': 16994, 'babyai-open-door': 13565, 'babyai-open-doors-order-n4': 9706, 'babyai-open-red-door': 21185, 'babyai-open-two-doors': 7348, 'babyai-open': 3331, 'babyai-pickup-above': 3392, 'babyai-pickup-dist': 19693, 'babyai-pickup-loc': 16405, 'babyai-pickup': 1806, 'babyai-put-next-local': 8303, 'babyai-put-next': 5703, 'babyai-synth-loc': 2183, 'babyai-synth-seq': 1316, 'babyai-synth': 1964, 'babyai-unblock-pickup': 1886, 'babyai-unlock-local': 7118, 'babyai-unlock-pickup': 5107, 'babyai-unlock-to-unlock': 2309, 'babyai-unlock': 1177, 'metaworld-assembly': 1000, 'metaworld-basketball': 1000, 'metaworld-bin-picking': 1000, 'metaworld-box-close': 1000, 'metaworld-button-press-topdown-wall': 1000, 'metaworld-button-press-topdown': 1000, 'metaworld-button-press-wall': 1000, 'metaworld-button-press': 1000, 'metaworld-coffee-button': 1000, 'metaworld-coffee-pull': 1000, 'metaworld-coffee-push': 1000, 'metaworld-dial-turn': 1000, 'metaworld-disassemble': 1000, 'metaworld-door-close': 1000, 'metaworld-door-lock': 1000, 'metaworld-door-open': 1000, 'metaworld-door-unlock': 1000, 'metaworld-drawer-close': 1000, 'metaworld-drawer-open': 1000, 'metaworld-faucet-close': 1000, 'metaworld-faucet-open': 1000, 'metaworld-hammer': 1000, 'metaworld-hand-insert': 1000, 'metaworld-handle-press-side': 1000, 'metaworld-handle-press': 1000, 'metaworld-handle-pull-side': 1000, 'metaworld-handle-pull': 1000, 'metaworld-lever-pull': 1000, 'metaworld-peg-insert-side': 1000, 'metaworld-peg-unplug-side': 1000, 
'metaworld-pick-out-of-hole': 1000, 'metaworld-pick-place-wall': 1000, 'metaworld-pick-place': 1000, 'metaworld-plate-slide-back-side': 1000, 'metaworld-plate-slide-back': 1000, 'metaworld-plate-slide-side': 1000, 'metaworld-plate-slide': 1000, 'metaworld-push-back': 1000, 'metaworld-push-wall': 1000, 'metaworld-push': 1000, 'metaworld-reach-wall': 1000, 'metaworld-reach': 1000, 'metaworld-shelf-place': 1000, 'metaworld-soccer': 1000, 'metaworld-stick-pull': 1000, 'metaworld-stick-push': 1000, 'metaworld-sweep-into': 1000, 'metaworld-sweep': 1000, 'metaworld-window-close': 1000, 'metaworld-window-open': 1000, 'mujoco-ant': 401, 'mujoco-doublependulum': 401, 'mujoco-halfcheetah': 400, 'mujoco-hopper': 491, 'mujoco-humanoid': 415, 'mujoco-pendulum': 495, 'mujoco-pusher': 1000, 'mujoco-reacher': 2000, 'mujoco-standup': 400, 'mujoco-swimmer': 400, 'mujoco-walker': 407}
+    return last_row_idx[task]
+
 def get_obs_dim(task):
     assert task.startswith("babyai") or task.startswith("metaworld") or task.startswith("mujoco")
 
all_obs_dims={'babyai-action-obj-door': 212, 'babyai-blocked-unlock-pickup': 212, 'babyai-boss-level-no-unlock': 212, 'babyai-boss-level': 212, 'babyai-find-obj-s5': 212, 'babyai-go-to-door': 212, 'babyai-go-to-imp-unlock': 212, 'babyai-go-to-local': 212, 'babyai-go-to-obj-door': 212, 'babyai-go-to-obj': 212, 'babyai-go-to-red-ball-grey': 212, 'babyai-go-to-red-ball-no-dists': 212, 'babyai-go-to-red-ball': 212, 'babyai-go-to-red-blue-ball': 212, 'babyai-go-to-seq': 212, 'babyai-go-to': 212, 'babyai-key-corridor': 212, 'babyai-mini-boss-level': 212, 'babyai-move-two-across-s8n9': 212, 'babyai-one-room-s8': 212, 'babyai-open-door': 212, 'babyai-open-doors-order-n4': 212, 'babyai-open-red-door': 212, 'babyai-open-two-doors': 212, 'babyai-open': 212, 'babyai-pickup-above': 212, 'babyai-pickup-dist': 212, 'babyai-pickup-loc': 212, 'babyai-pickup': 212, 'babyai-put-next-local': 212, 'babyai-put-next': 212, 'babyai-synth-loc': 212, 'babyai-synth-seq': 212, 'babyai-synth': 212, 'babyai-unblock-pickup': 212, 'babyai-unlock-local': 212, 'babyai-unlock-pickup': 212, 'babyai-unlock-to-unlock': 212, 'babyai-unlock': 212, 'metaworld-assembly': 39, 'metaworld-basketball': 39, 'metaworld-bin-picking': 39, 'metaworld-box-close': 39, 'metaworld-button-press-topdown-wall': 39, 'metaworld-button-press-topdown': 39, 'metaworld-button-press-wall': 39, 'metaworld-button-press': 39, 'metaworld-coffee-button': 39, 'metaworld-coffee-pull': 39, 'metaworld-coffee-push': 39, 'metaworld-dial-turn': 39, 'metaworld-disassemble': 39, 'metaworld-door-close': 39, 'metaworld-door-lock': 39, 'metaworld-door-open': 39, 'metaworld-door-unlock': 39, 'metaworld-drawer-close': 39, 'metaworld-drawer-open': 39, 'metaworld-faucet-close': 39, 'metaworld-faucet-open': 39, 'metaworld-hammer': 39, 'metaworld-hand-insert': 39, 'metaworld-handle-press-side': 39, 'metaworld-handle-press': 39, 'metaworld-handle-pull-side': 39, 'metaworld-handle-pull': 39, 'metaworld-lever-pull': 39, 'metaworld-peg-insert-side': 39, 'metaworld-peg-unplug-side': 39, 'metaworld-pick-out-of-hole': 39, 'metaworld-pick-place-wall': 39, 'metaworld-pick-place': 39, 'metaworld-plate-slide-back-side': 39, 'metaworld-plate-slide-back': 39, 'metaworld-plate-slide-side': 39, 'metaworld-plate-slide': 39, 'metaworld-push-back': 39, 'metaworld-push-wall': 39, 'metaworld-push': 39, 'metaworld-reach-wall': 39, 'metaworld-reach': 39, 'metaworld-shelf-place': 39, 'metaworld-soccer': 39, 'metaworld-stick-pull': 39, 'metaworld-stick-push': 39, 'metaworld-sweep-into': 39, 'metaworld-sweep': 39, 'metaworld-window-close': 39, 'metaworld-window-open': 39, 'mujoco-ant': 27, 'mujoco-doublependulum': 11, 'mujoco-halfcheetah': 17, 'mujoco-hopper': 11, 'mujoco-humanoid': 376, 'mujoco-pendulum': 4, 'mujoco-pusher': 23, 'mujoco-reacher': 11, 'mujoco-standup': 376, 'mujoco-swimmer': 8, 'mujoco-walker': 17}
|
140 |
+
- return all_obs_dims[task]
|
141 |
+
+ return (all_obs_dims[task],)
|
142 |
+
|
143 |
+
def get_act_dim(task):
|
144 |
+
assert task.startswith("babyai") or task.startswith("metaworld") or task.startswith("mujoco")
|
145 |
+
@@ -36,141 +48,188 @@ def get_act_dim(task):
|
146 |
+
elif task.startswith("mujoco"):
|
147 |
+
all_act_dims={'mujoco-ant': 8, 'mujoco-doublependulum': 1, 'mujoco-halfcheetah': 6, 'mujoco-hopper': 3, 'mujoco-humanoid': 17, 'mujoco-pendulum': 1, 'mujoco-pusher': 7, 'mujoco-reacher': 2, 'mujoco-standup': 17, 'mujoco-swimmer': 2, 'mujoco-walker': 6}
|
148 |
+
return all_act_dims[task]
|
149 |
+
-
|
150 |
+
-def process_row_atari(attn_mask, row_of_obs, task):
|
151 |
+
- """
|
152 |
+
- Example for selection with bools:
|
153 |
+
- >>> a = np.array([0,1,2,3,4,5])
|
154 |
+
- >>> b = np.array([1,0,0,0,0,1]).astype(bool)
|
155 |
+
- >>> a[b]
|
156 |
+
- array([0, 5])
|
157 |
+
- """
|
158 |
+
- attn_mask = np.array(attn_mask).astype(bool)
|
159 |
+
|
160 |
+
- row_of_obs = torch.stack([to_tensor(np.array(img)) for img in row_of_obs])
|
161 |
+
- row_of_obs = row_of_obs[attn_mask]
|
162 |
+
+def get_task_info(task):
|
163 |
+
+ rew_key = 'rewards'
|
164 |
+
+ attn_key = 'attention_mask'
|
165 |
+
+ if task.startswith("atari"):
|
166 |
+
+ obs_key = 'image_observations'
|
167 |
+
+ act_key = 'discrete_actions'
|
168 |
+
+ B = 32 # half of 54
|
169 |
+
+ obs_dim = (3, 4*84, 84)
|
170 |
+
+ elif task.startswith("babyai"):
|
171 |
+
+ obs_key = 'discrete_observations' # also has 'text_observations' only for raw dataset not for tokenized dataset (as it is combined into discrete_observation in tokenized dataset)
|
172 |
+
+ act_key = 'discrete_actions'
|
173 |
+
+ B = 256 # half of 512
|
174 |
+
+ obs_dim = get_obs_dim(task)
|
175 |
+
+ elif task.startswith("metaworld") or task.startswith("mujoco"):
|
176 |
+
+ obs_key = 'continuous_observations'
|
177 |
+
+ act_key = 'continuous_actions'
|
178 |
+
+ B = 256
|
179 |
+
+ obs_dim = get_obs_dim(task)
|
180 |
+
+
|
181 |
+
+ return rew_key, attn_key, obs_key, act_key, B, obs_dim
|
182 |
+
+
|
183 |
+
+def process_row_of_obs_atari_full_without_mask(row_of_obs):
|
184 |
+
+
|
185 |
+
+ if not isinstance(row_of_obs, torch.Tensor):
|
186 |
+
+ row_of_obs = torch.stack([to_tensor(np.array(img)) for img in row_of_obs])
|
187 |
+
row_of_obs = row_of_obs * 0.5 + 0.5 # denormalize from [-1, 1] to [0, 1]
|
188 |
+
- assert row_of_obs.shape == (sum(attn_mask), 84, 4, 84)
|
189 |
+
+ assert row_of_obs.shape == (len(row_of_obs), 84, 4, 84)
|
190 |
+
row_of_obs = row_of_obs.permute(0, 2, 1, 3) # (*, 4, 84, 84)
|
191 |
+
- row_of_obs = row_of_obs.reshape(sum(attn_mask), 4*84, 84) # put side-by-side
|
192 |
+
+ row_of_obs = row_of_obs.reshape(len(row_of_obs), 4*84, 84) # put side-by-side
|
193 |
+
row_of_obs = row_of_obs.unsqueeze(1).repeat(1, 3, 1, 1) # repeat for 3 channels
|
194 |
+
- assert row_of_obs.shape == (sum(attn_mask), 3, 4*84, 84) # sum(attn_mask) is the batch size dimension
|
195 |
+
-
|
196 |
+
- return attn_mask, row_of_obs
|
197 |
+
+ assert row_of_obs.shape == (len(row_of_obs), 3, 4*84, 84) # sum(attn_mask) is the batch size dimension
|
198 |
+
+
|
199 |
+
+ return row_of_obs
|
200 |
+
|
201 |
+
-def process_row_vector(attn_mask, row_of_obs, task, return_numpy=False):
|
202 |
+
- attn_mask = np.array(attn_mask).astype(bool)
|
203 |
+
+def collect_all_atari_data(dataset, all_row_idxs=None):
|
204 |
+
+ if all_row_idxs is None:
|
205 |
+
+ all_row_idxs = list(range(len(dataset['train'])))
|
206 |
+
|
207 |
+
- row_of_obs = np.array(row_of_obs)
|
208 |
+
- if not return_numpy:
|
209 |
+
- row_of_obs = torch.tensor(row_of_obs)
|
210 |
+
- row_of_obs = row_of_obs[attn_mask]
|
211 |
+
- assert row_of_obs.shape == (sum(attn_mask), get_obs_dim(task))
|
212 |
+
-
|
213 |
+
- return attn_mask, row_of_obs
|
214 |
+
-
|
215 |
+
-def retrieve_atari(row_of_obs, # query: (row_B, 3, 4*84, 84)
|
216 |
+
- dataset, # to retrieve from
|
217 |
+
- all_rows_to_consider, # rows to consider
|
218 |
+
- num_to_retrieve, # top-k
|
219 |
+
+ all_rows_of_obs = []
|
220 |
+
+ all_attn_masks = []
|
221 |
+
+ for row_idx in tqdm(all_row_idxs):
|
222 |
+
+ datarow = dataset['train'][row_idx]
|
223 |
+
+ row_of_obs = process_row_of_obs_atari_full_without_mask(datarow['image_observations'])
|
224 |
+
+ attn_mask = np.array(datarow['attention_mask']).astype(bool)
|
225 |
+
+ all_rows_of_obs.append(row_of_obs) # appending tensor
|
226 |
+
+ all_attn_masks.append(attn_mask) # appending np array
|
227 |
+
+ all_rows_of_obs = torch.stack(all_rows_of_obs, dim=0) # stacking tensors
|
228 |
+
+ all_attn_masks = np.stack(all_attn_masks, axis=0) # concatenating np arrays
|
229 |
+
+ assert (all_rows_of_obs.shape == (len(all_row_idxs), 32, 3, 4*84, 84) and
|
230 |
+
+ all_attn_masks.shape == (len(all_row_idxs), 32))
|
231 |
+
+ return all_attn_masks, all_rows_of_obs
|
232 |
+
+
|
233 |
+
+def collect_all_data(dataset, task, obs_key):
|
234 |
+
+ last_row_idx = get_last_row_for_100k_states(task)
|
235 |
+
+ all_row_idxs = list(range(last_row_idx))
|
236 |
+
+ if task.startswith("atari"):
|
237 |
+
+ myprint("Collecting all Atari images and Atari attention masks...")
|
238 |
+
+ all_attn_masks_OG, all_rows_of_obs_OG = collect_all_atari_data(dataset, all_row_idxs)
|
239 |
+
+ else:
|
240 |
+
+ datarows = dataset['train'][all_row_idxs]
|
241 |
+
+ all_rows_of_obs_OG = np.array(datarows[obs_key])
|
242 |
+
+ all_attn_masks_OG = np.array(datarows['attention_mask']).astype(bool)
|
243 |
+
+ return all_rows_of_obs_OG, all_attn_masks_OG, all_row_idxs
|
244 |
+
+
|
245 |
+
+def collect_subset(all_rows_of_obs_OG,
|
246 |
+
+ all_attn_masks_OG,
|
247 |
+
+ all_rows_to_consider,
|
248 |
+
+ kwargs
|
249 |
+
+ ):
|
250 |
+
+ """
|
251 |
+
+ Function to collect subset of data given all_rows_to_consider, reshape it, create all_indices and return.
|
252 |
+
+ Used in both retrieve_atari() and retrieve_vector() --> build_index_vector().
|
253 |
+
+ """
|
254 |
+
+ myprint(f'\n\n\n' + ('-'*100) + f'Collecting subset...')
|
255 |
+
+ # read kwargs
|
256 |
+
+ B, task, obs_dim = kwargs['B'], kwargs['task'], kwargs['obs_dim']
|
257 |
+
+
|
258 |
+
+ # take subset based on all_rows_to_consider
|
259 |
+
+ myprint(f'Taking subset of data based on all_rows_to_consider...')
|
260 |
+
+ all_processed_rows_of_obs = all_rows_of_obs_OG[all_rows_to_consider]
|
261 |
+
+ all_attn_masks = all_attn_masks_OG[all_rows_to_consider]
|
262 |
+
+ assert (all_processed_rows_of_obs.shape == (len(all_rows_to_consider), B, *obs_dim) and
|
263 |
+
+ all_attn_masks.shape == (len(all_rows_to_consider), B))
|
264 |
+
+
|
265 |
+
+ # reshape
|
266 |
+
+ myprint(f'Reshaping data...')
|
267 |
+
+ all_attn_masks = all_attn_masks.reshape(-1)
|
268 |
+
+ all_processed_rows_of_obs = all_processed_rows_of_obs.reshape(-1, *obs_dim)
|
269 |
+
+ all_processed_rows_of_obs = all_processed_rows_of_obs[all_attn_masks]
|
270 |
+
+ assert (all_attn_masks.shape == (len(all_rows_to_consider) * B,) and
|
271 |
+
+ all_processed_rows_of_obs.shape == (np.sum(all_attn_masks), *obs_dim))
|
272 |
+
+
|
273 |
+
+ # collect indices of data
|
274 |
+
+ myprint(f'Collecting indices of data...')
|
275 |
+
+ all_indices = np.array([[row_idx, i] for row_idx in all_rows_to_consider for i in range(B)])
|
276 |
+
+ all_indices = all_indices[all_attn_masks] # this is fine because all attn masks have 0s that only come after 1s
|
277 |
+
+ assert all_indices.shape == (np.sum(all_attn_masks), 2)
|
278 |
+
+
|
279 |
+
+ myprint(f'{all_indices.shape=}, {all_processed_rows_of_obs.shape=}')
|
280 |
+
+ myprint(('-'*100) + '\n\n\n')
|
281 |
+
+ return all_indices, all_processed_rows_of_obs
|
282 |
+
+
|
283 |
+
+def retrieve_atari(row_of_obs, # query: (xbdim, 3, 4*84, 84) / (xdim *obs_dim)
|
284 |
+
+ all_processed_rows_of_obs,
|
285 |
+
+ all_indices,
|
286 |
+
+ num_to_retrieve,
|
287 |
+
kwargs
|
288 |
+
- ):
|
289 |
+
+ ):
|
290 |
+
+ """
|
291 |
+
+ Retrieval for Atari with images, ssim distance, and on GPU.
|
292 |
+
+ """
|
293 |
+
assert isinstance(row_of_obs, torch.Tensor)
|
294 |
+
|
295 |
+
# read kwargs # Note: B = len of row
|
296 |
+
- B, attn_key, obs_key, device, task, batch_size_retrieval = kwargs['B'], kwargs['attn_key'], kwargs['obs_key'], kwargs['device'], kwargs['task'], kwargs['batch_size_retrieval']
|
297 |
+
+ B, device, batch_size_retrieval = kwargs['B'], kwargs['device'], kwargs['batch_size_retrieval']
|
298 |
+
|
299 |
+
# batch size of row_of_obs which can be <= B since we process before calling this function
|
300 |
+
- row_B = row_of_obs.shape[0]
|
301 |
+
-
|
302 |
+
+ xbdim = row_of_obs.shape[0]
|
303 |
+
+
|
304 |
+
+ # collect subset of data that we can retrieve from
|
305 |
+
+ ydim = all_processed_rows_of_obs.shape[0]
|
306 |
+
+
|
307 |
+
# first argument for ssim
|
308 |
+
- repeated_row_og = row_of_obs.repeat_interleave(B, dim=0).to(device)
|
309 |
+
- assert repeated_row_og.shape == (row_B*B, 3, 4*84, 84)
|
310 |
+
+ xbatch = row_of_obs.repeat_interleave(batch_size_retrieval, dim=0).to(device)
|
311 |
+
+ assert xbatch.shape == (xbdim * batch_size_retrieval, 3, 4*84, 84)
|
312 |
+
|
313 |
+
- # iterate over all other rows
|
314 |
+
+ # iterate over data that we can retrieve from in batches
|
315 |
+
all_ssim = []
|
316 |
+
- all_indices = []
|
317 |
+
- total = 0
|
318 |
+
- for other_row_idx in tqdm(all_rows_to_consider):
|
319 |
+
- other_attn_mask, other_row_of_obs = process_row_atari(dataset['train'][other_row_idx][attn_key], dataset['train'][other_row_idx][obs_key])
|
320 |
+
-
|
321 |
+
- # batch size of other_row_of_obs
|
322 |
+
- other_row_B = other_row_of_obs.shape[0]
|
323 |
+
- total += other_row_B
|
324 |
+
-
|
325 |
+
- # first argument for ssim: RECHECK
|
326 |
+
- if other_row_B < B: # when other row has less observations than expected
|
327 |
+
- repeated_row = row_of_obs.repeat_interleave(other_row_B, dim=0).to(device)
|
328 |
+
- elif other_row_B == B: # otherwise just use the one created before the for loop
|
329 |
+
- repeated_row = repeated_row_og
|
330 |
+
- assert repeated_row.shape == (row_B*other_row_B, 3, 4*84, 84)
|
331 |
+
-
|
332 |
+
+ for j in range(0, ydim, batch_size_retrieval):
|
333 |
+
# second argument for ssim
|
334 |
+
- repeated_other_row = other_row_of_obs.repeat(row_B, 1, 1, 1).to(device)
|
335 |
+
- assert repeated_other_row.shape == (row_B*other_row_B, 3, 4*84, 84)
|
336 |
+
+ ybatch = all_processed_rows_of_obs[j:j+batch_size_retrieval]
|
337 |
+
+ ybdim = ybatch.shape[0]
|
338 |
+
+ ybatch = ybatch.repeat(xbdim, 1, 1, 1).to(device)
|
339 |
+
+ assert ybatch.shape == (ybdim * xbdim, 3, 4*84, 84)
|
340 |
+
+
|
341 |
+
+ if ybdim < batch_size_retrieval: # for last batch
|
342 |
+
+ xbatch = row_of_obs.repeat_interleave(ybdim, dim=0).to(device)
|
343 |
+
+ assert xbatch.shape == (xbdim * ybdim, 3, 4*84, 84)
|
344 |
+
|
345 |
+
# compare via ssim and updated all_ssim
|
346 |
+
- ssim_score = ssim(repeated_row, repeated_other_row, data_range=1.0, size_average=False)
|
347 |
+
- ssim_score = ssim_score.reshape(row_B, other_row_B)
|
348 |
+
+ ssim_score = ssim(xbatch, ybatch, data_range=1.0, size_average=False)
|
349 |
+
+ ssim_score = ssim_score.reshape(xbdim, ybdim)
|
350 |
+
all_ssim.append(ssim_score)
|
351 |
+
|
352 |
+
- # update all_indices
|
353 |
+
- all_indices.extend([[other_row_idx, i] for i in range(other_row_B)])
|
354 |
+
-
|
355 |
+
# concat
|
356 |
+
all_ssim = torch.cat(all_ssim, dim=1)
|
357 |
+
- assert all_ssim.shape == (row_B, total)
|
358 |
+
+ assert all_ssim.shape == (xbdim, ydim)
|
359 |
+
|
360 |
+
- all_indices = np.array(all_indices)
|
361 |
+
- assert all_indices.shape == (total, 2)
|
362 |
+
+ assert all_indices.shape == (ydim, 2)
|
363 |
+
|
364 |
+
# get top-k indices
|
365 |
+
topk_values, topk_indices = torch.topk(all_ssim, num_to_retrieve, dim=1, largest=True)
|
366 |
+
topk_indices = topk_indices.cpu().numpy()
|
367 |
+
- assert topk_indices.shape == (row_B, num_to_retrieve)
|
368 |
+
+ assert topk_indices.shape == (xbdim, num_to_retrieve)
|
369 |
+
|
370 |
+
# convert topk indices to indices in the dataset
|
371 |
+
- retrieved_indices = np.array(all_indices[topk_indices])
|
372 |
+
- assert retrieved_indices.shape == (row_B, num_to_retrieve, 2)
|
373 |
+
-
|
374 |
+
- # pad the above to expected B
|
375 |
+
- if row_B < B:
|
376 |
+
- retrieved_indices = np.concatenate([retrieved_indices, np.zeros((B-row_B, num_to_retrieve, 2), dtype=int)], axis=0)
|
377 |
+
- assert retrieved_indices.shape == (B, num_to_retrieve, 2)
|
378 |
+
+ retrieved_indices = all_indices[topk_indices]
|
379 |
+
+ assert retrieved_indices.shape == (xbdim, num_to_retrieve, 2)
|
380 |
+
|
381 |
+
return retrieved_indices
|
382 |
+
|
383 |
+
-def build_index_vector(all_rows_of_obs_og,
-                       all_attn_masks_og,
+def build_index_vector(all_rows_of_obs_OG,
+                       all_attn_masks_OG,
                        all_rows_to_consider,
                        kwargs
-                       ):
+                       ):
+    """
+    Builds FAISS index for vector observation environments.
+    """
     # read kwargs # Note: B = len of row
-    B, attn_key, obs_key, device, task, batch_size_retrieval, nb_cores_autofaiss = kwargs['B'], kwargs['attn_key'], kwargs['obs_key'], kwargs['device'], kwargs['task'], kwargs['batch_size_retrieval'], kwargs['nb_cores_autofaiss']
-    obs_dim = get_obs_dim(task)
+    nb_cores_autofaiss = kwargs['nb_cores_autofaiss']

-    # take subset based on all_rows_to_consider
-    myprint(f'Taking subset')
-    all_rows_of_obs = all_rows_of_obs_og[all_rows_to_consider]
-    all_attn_masks = all_attn_masks_og[all_rows_to_consider]
-    assert (all_rows_of_obs.shape == (len(all_rows_to_consider), B, obs_dim) and
-            all_attn_masks.shape == (len(all_rows_to_consider), B))
-
-    # reshape
-    all_attn_masks = all_attn_masks.reshape(-1)
-    all_rows_of_obs = all_rows_of_obs.reshape(-1, obs_dim)
-    all_rows_of_obs = all_rows_of_obs[all_attn_masks]
-    assert all_rows_of_obs.shape == (np.sum(all_attn_masks), obs_dim)
+    # take subset based on all_rows_to_consider, reshape, and save indices of data
+    all_indices, all_processed_rows_of_obs = collect_subset(all_rows_of_obs_OG, all_attn_masks_OG, all_rows_to_consider, kwargs)

-    # save indices of data to retrieve from
-    myprint(f'Saving indices of data to retrieve from')
-    all_indices = np.array([[row_idx, i] for row_idx in all_rows_to_consider for i in range(B)])
-    all_indices = all_indices[all_attn_masks] # this is fine because all attn masks have 0s that only come after 1s
-    assert all_indices.shape == (np.sum(all_attn_masks), 2)
+    # make sure input to build_index is float, otherwise you will get reading temp file error
+    all_processed_rows_of_obs = all_processed_rows_of_obs.astype(float)

     # build index
-    myprint(f'Building index...')
-    knn_index, knn_index_infos = build_index(embeddings=all_rows_of_obs, # Note: embeddings have to be float to avoid errors in autofaiss / embedding_reader!
+    myprint(('-'*100) + 'Building index...')
+    knn_index, knn_index_infos = build_index(embeddings=all_processed_rows_of_obs, # Note: embeddings have to be float to avoid errors in autofaiss / embedding_reader!
                                              save_on_disk=False,
                                              min_nearest_neighbors_to_retrieve=20, # default: 20
                                              max_index_query_time_ms=10, # default: 10
@@ -179,34 +238,32 @@ def build_index_vector(all_rows_of_obs_og,
                                              metric_type='l2',
                                              nb_cores=nb_cores_autofaiss, # default: None # "The number of cores to use, by default will use all cores" as seen in https://criteo.github.io/autofaiss/getting_started/quantization.html#the-build-index-command
                                              )
+    myprint(('-'*100) + '\n\n\n')

-    return knn_index, all_indices
+    return all_indices, knn_index

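With the subsetting and index bookkeeping moved into `collect_subset`, the only index-specific work left in this function is the autofaiss call. A hedged, standalone sketch of that call with dummy embeddings; the keyword arguments mirror the ones in the hunk, while the array shape and core count are placeholders.

# Illustrative: build an in-memory autofaiss index over a float embedding matrix.
import numpy as np
from autofaiss import build_index

embeddings = np.random.rand(10_000, 39).astype(float)   # (num_obs, obs_dim); must be float
knn_index, knn_index_infos = build_index(
    embeddings=embeddings,
    save_on_disk=False,
    min_nearest_neighbors_to_retrieve=20,
    max_index_query_time_ms=10,
    metric_type='l2',
    nb_cores=16,
)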
-def retrieve_vector(row_of_obs, # query: (row_B, dim)
-                    dataset, # to retrieve from
-                    all_rows_to_consider, # rows to consider
-                    num_to_retrieve, # top-k
+def retrieve_vector(row_of_obs, # query: (xbdim, *obs_dim)
+                    knn_index,
+                    all_indices,
+                    num_to_retrieve,
                     kwargs
-                    ):
+                    ):
+    """
+    Retrieval for vector observation environments.
+    """
     assert isinstance(row_of_obs, np.ndarray)

     # read few kwargs
     B = kwargs['B']

     # batch size of row_of_obs which can be <= B since we process before calling this function
-    row_B = row_of_obs.shape[0]
+    xbdim = row_of_obs.shape[0]

-    # read dataset_tuple
-    all_rows_of_obs, all_attn_masks = dataset
-
-    # create index and all_indices
-    knn_index, all_indices = build_index_vector(all_rows_of_obs, all_attn_masks, all_rows_to_consider, kwargs)
-
     # retrieve
     myprint(f'Retrieving...')
     topk_indices, _ = knn_index.search(row_of_obs, 10 * num_to_retrieve)
     topk_indices = topk_indices.astype(int)
-    assert topk_indices.shape == (row_B, 10 * num_to_retrieve)
+    assert topk_indices.shape == (xbdim, 10 * num_to_retrieve)

     # remove -1s and crop to num_to_retrieve
     try:
@@ -219,16 +276,10 @@ def retrieve_vector(row_of_obs, # query: (row_B, dim)
         print(f'-------------------------------------------------------------------------------------------------------------------------------------------')
         print(f'Leaving some -1s in topk_indices and continuing')
     topk_indices = np.array([indices[:num_to_retrieve] for indices in topk_indices])
-    assert topk_indices.shape == (row_B, num_to_retrieve)
+    assert topk_indices.shape == (xbdim, num_to_retrieve)

     # convert topk indices to indices in the dataset
     retrieved_indices = all_indices[topk_indices]
-    assert retrieved_indices.shape == (row_B, num_to_retrieve, 2)
-
-    # pad the above to expected B
-    if row_B < B:
-        retrieved_indices = np.concatenate([retrieved_indices, np.zeros((B-row_B, num_to_retrieve, 2), dtype=int)], axis=0)
-    assert retrieved_indices.shape == (B, num_to_retrieve, 2)
+    assert retrieved_indices.shape == (xbdim, num_to_retrieve, 2)

-    myprint(f'Returning')
     return retrieved_indices

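The query side is now a plain k-NN search against a prebuilt index, followed by a lookup into the (row_idx, step_idx) table that was built alongside it. A minimal sketch of that step with illustrative names; note that a raw faiss index returns (distances, labels) from `search` and pads missing neighbours with -1, which is why the hunk over-fetches 10x the requested count before cropping.

# Illustrative: query a prebuilt index and map flat hits back to (row_idx, step_idx) pairs.
import numpy as np

def query_index(knn_index, all_indices, queries, k):
    # queries: (Q, obs_dim) float array; all_indices: (N, 2) lookup built alongside the index
    dists, labels = knn_index.search(np.ascontiguousarray(queries, dtype=np.float32), 10 * k)
    labels = labels.astype(int)
    # keep the first k valid hits per query (assumes every query has at least k valid neighbours)
    topk = np.stack([row[row >= 0][:k] for row in labels])
    return all_indices[topk]                                   # (Q, k, 2)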
diff --git a/scripts_regent/eval_RandP.py b/scripts_regent/eval_RandP.py
index 07e545c..146b347 100755
--- a/scripts_regent/eval_RandP.py
+++ b/scripts_regent/eval_RandP.py
@@ -15,9 +15,10 @@ from transformers import AutoModelForCausalLM, AutoProcessor, HfArgumentParser

 from jat.eval.rl import TASK_NAME_TO_ENV_ID, make
 from jat.utils import normalize, push_to_hub, save_video_grid
-from jat_regent.RandP import RandP
+from jat_regent.modeling_RandP import RandP
 from datasets import load_from_disk
 from datasets.config import HF_DATASETS_CACHE
+from jat_regent.utils import myprint


 @dataclass
@@ -70,6 +71,7 @@ def eval_rl(model, processor, task, eval_args):
     scores = []
     frames = []
     for episode in tqdm(range(eval_args.num_episodes), desc=task, unit="episode", leave=False):
+        myprint(('-'*100) + f'{episode=}')
         observation, _ = env.reset()
         reward = None
         rewards = []
@@ -96,6 +98,7 @@ def eval_rl(model, processor, task, eval_args):
             frames.append(np.array(env.render(), dtype=np.uint8))

         scores.append(sum(rewards))
+        myprint(('-'*100) + '\n\n\n')
     env.close()

     raw_mean, raw_std = np.mean(scores), np.std(scores)
@@ -145,7 +148,9 @@ def main():
         tasks.extend([env_id for env_id in TASK_NAME_TO_ENV_ID.keys() if env_id.startswith(domain)])

     device = torch.device("cpu") if eval_args.use_cpu else get_default_device()
-    processor = None
+    processor = AutoProcessor.from_pretrained(
+        'jat-project/jat', cache_dir=None, trust_remote_code=True
+    )

     evaluations = {}
     video_list = []
@@ -153,14 +158,18 @@ def main():

     for task in tqdm(tasks, desc="Evaluation", unit="task", leave=True):
         if task in TASK_NAME_TO_ENV_ID.keys():
+            myprint(('-'*100) + f'{task=}')
             dataset = load_from_disk(f'{HF_DATASETS_CACHE}/jat-project/jat-dataset-tokenized/{task}')
-            model = RandP(dataset)
+            model = RandP(task,
+                          dataset,
+                          device,)
             scores, frames, fps = eval_rl(model, processor, task, eval_args)
             evaluations[task] = scores
             # Save the video
             if eval_args.save_video:
                 video_list.append(frames)
                 input_fps.append(fps)
+            myprint(('-'*100) + '\n\n\n')
         else:
             warnings.warn(f"Task {task} is not supported.")

diff --git a/scripts_regent/offline_retrieval_jat_regent.py b/scripts_regent/offline_retrieval_jat_regent.py
index c83d259..aad678a 100644
--- a/scripts_regent/offline_retrieval_jat_regent.py
+++ b/scripts_regent/offline_retrieval_jat_regent.py
@@ -8,7 +8,7 @@ import time
 from datetime import datetime
 from datasets import load_from_disk
 from datasets.config import HF_DATASETS_CACHE
-from jat_regent.utils import myprint, process_row_atari, process_row_vector, retrieve_atari, retrieve_vector
+from jat_regent.utils import myprint, get_task_info, collect_all_data, process_row_of_obs_atari_full_without_mask, retrieve_atari, retrieve_vector, collect_subset, build_index_vector
 import logging
 logging.basicConfig(level=logging.DEBUG)

@@ -17,7 +17,8 @@ def main():
     parser = argparse.ArgumentParser(description='Build RAAGENT sequence indices')
     parser.add_argument('--task', type=str, default='atari-alien', help='Task name')
     parser.add_argument('--num_to_retrieve', type=int, default=100, help='Number of states/windows to retrieve')
-    parser.add_argument('--nb_cores_autofaiss', type=int, default=16, help='Number of cores to use for faiss in vector observation environments')
+    parser.add_argument('--nb_cores_autofaiss', type=int, default=16, help='Number of cores to use for faiss in vector obs envs')
+    parser.add_argument('--batch_size_retrieval', type=int, default=1024, help='Batch size for retrieval in atari')
     args = parser.parse_args()

     # load dataset, map, device, for task
@@ -25,77 +26,83 @@ def main():
     dataset_path = f"{HF_DATASETS_CACHE}/jat-project/jat-dataset-tokenized/{task}"
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

-    rew_key = 'rewards'
-    attn_key = 'attention_mask'
-    if task.startswith("atari"):
-        obs_key = 'image_observations'
-        act_key = 'discrete_actions'
-        len_row_tokenized_known = 32 # half of 54
-        process_row_fn = process_row_atari
-        retrieve_fn = retrieve_atari
-    elif task.startswith("babyai"):
-        obs_key = 'discrete_observations'# also has 'text_observations' only for raw dataset not for tokenized dataset (as it is combined into discrete_observation in tokenized dataset)
-        act_key = 'discrete_actions'
-        len_row_tokenized_known = 256 # half of 512
-        process_row_fn = lambda attn_mask, row_of_obs, task: process_row_vector(attn_mask, row_of_obs, task, return_numpy=True)
-        retrieve_fn = retrieve_vector
-    elif task.startswith("metaworld") or task.startswith("mujoco"):
-        obs_key = 'continuous_observations'
-        act_key = 'continuous_actions'
-        len_row_tokenized_known = 256
-        process_row_fn = lambda attn_mask, row_of_obs, task: process_row_vector(attn_mask, row_of_obs, task, return_numpy=True)
-        retrieve_fn = retrieve_vector
+    rew_key, attn_key, obs_key, act_key, B, obs_dim = get_task_info(task)

     dataset = load_from_disk(dataset_path)
     with open(f"{dataset_path}/map_from_rows_to_episodes_for_tokenized.json", 'r') as f:
         map_from_rows_to_episodes_for_tokenized = json.load(f)

     # setup kwargs
-    len_dataset = len(dataset['train'])
-    B = len_row_tokenized_known
     kwargs = {'B': B,
-              'attn_key':attn_key,
-              'obs_key':obs_key,
-              'device':device,
-              'task':task,
-              'batch_size_retrieval':None,
-              'nb_cores_autofaiss':None if task.startswith("atari") else args.nb_cores_autofaiss,
-              }
+              'obs_dim': obs_dim,
+              'attn_key': attn_key,
+              'obs_key': obs_key,
+              'device': device,
+              'task': task,
+              'batch_size_retrieval': args.batch_size_retrieval,
+              'nb_cores_autofaiss': None if task.startswith("atari") else args.nb_cores_autofaiss,
+              }

     # collect all observations in a single array (this takes some time) for vector observation environments
-    if not task.startswith("atari"):
-        myprint("Collecting all observations/attn_masks in a single array")
-        all_rows_of_obs = np.array(dataset['train'][obs_key])
-        all_attn_masks = np.array(dataset['train'][attn_key]).astype(bool)
+    myprint("Collecting all observations/attn_masks")
+    all_rows_of_obs_OG, all_attn_masks_OG, all_row_idxs = collect_all_data(dataset, task, obs_key)

     # iterate over rows
     all_retrieved_indices = []
-    for row_idx in range(len_dataset):
-        myprint(f"\nProcessing row {row_idx}/{len_dataset}")
+    for row_idx in all_row_idxs:
+        myprint(f"\nProcessing row {row_idx}/{len(all_row_idxs)}")
         current_ep = map_from_rows_to_episodes_for_tokenized[str(row_idx)]

-        attn_mask, row_of_obs = process_row_fn(dataset['train'][row_idx][attn_key], dataset['train'][row_idx][obs_key], task)
+        # get row_of_obs and attn_mask
+        datarow = dataset['train'][row_idx]
+        attn_mask = np.array(datarow[attn_key]).astype(bool)
+        if task.startswith("atari"):
+            row_of_obs = process_row_of_obs_atari_full_without_mask(datarow[obs_key])
+        else:
+            row_of_obs = np.array(datarow[obs_key])
+        row_of_obs = row_of_obs[attn_mask]
+        assert row_of_obs.shape == (np.sum(attn_mask), *obs_dim)

         # compare with rows from all but the current episode
-        all_other_rows = [idx for idx in range(len_dataset) if map_from_rows_to_episodes_for_tokenized[str(idx)] != current_ep]
+        all_other_row_idxs = [idx for idx in all_row_idxs if map_from_rows_to_episodes_for_tokenized[str(idx)] != current_ep]

         # do the retrieval
-        retrieved_indices = retrieve_fn(row_of_obs=row_of_obs,
-                                        dataset=dataset if task.startswith("atari") else (all_rows_of_obs, all_attn_masks),
-                                        all_rows_to_consider=all_other_rows,
-                                        num_to_retrieve=args.num_to_retrieve,
-                                        kwargs=kwargs,
-                                        )
+        if task.startswith("atari"):
+            all_indices, all_processed_rows_of_obs = collect_subset(all_rows_of_obs_OG=all_rows_of_obs_OG,
+                                                                    all_attn_masks_OG=all_attn_masks_OG,
+                                                                    all_rows_to_consider=all_row_idxs,
+                                                                    kwargs=kwargs)
+            retrieved_indices = retrieve_atari(row_of_obs=row_of_obs,
+                                               all_processed_rows_of_obs=all_processed_rows_of_obs,
+                                               all_indices=all_indices,
+                                               num_to_retrieve=args.num_to_retrieve,
+                                               kwargs=kwargs)
+        else:
+            all_indices, knn_index = build_index_vector(all_rows_of_obs_OG=all_rows_of_obs_OG,
+                                                        all_attn_masks_OG=all_attn_masks_OG,
+                                                        all_rows_to_consider=all_other_row_idxs,
+                                                        kwargs=kwargs)
+            retrieved_indices = retrieve_vector(row_of_obs=row_of_obs,
+                                                knn_index=knn_index,
+                                                all_indices=all_indices,
+                                                num_to_retrieve=args.num_to_retrieve,
+                                                kwargs=kwargs)
+
+        # pad the above to expected B
+        xbdim = row_of_obs.shape[0]
+        if xbdim < B:
+            retrieved_indices = np.concatenate([retrieved_indices, np.zeros((B-xbdim, args.num_to_retrieve, 2), dtype=int)], axis=0)
+        assert retrieved_indices.shape == (B, args.num_to_retrieve, 2)

         # collect retrieved indices
         all_retrieved_indices.append(retrieved_indices)

     # concat
     all_retrieved_indices = np.stack(all_retrieved_indices, axis=0)
-    assert all_retrieved_indices.shape == (len_dataset, B, args.num_to_retrieve, 2)
+    assert all_retrieved_indices.shape == (len(all_row_idxs), B, args.num_to_retrieve, 2)

     # save arrays as bin for easy memmap access and faster loading
-    all_retrieved_indices.tofile(f"{dataset_path}/retrieved_indices_{len_dataset}_{B}_{args.num_to_retrieve}_2.bin")
+    all_retrieved_indices.tofile(f"{dataset_path}/retrieved_indices_{len(all_row_idxs)}_{B}_{args.num_to_retrieve}_2.bin")

 if __name__ == "__main__":
     main()

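`ndarray.tofile` writes raw bytes with no header, so whatever memmaps the file later has to supply the dtype and shape itself. A hedged read-back sketch: the dtype is assumed to be the platform default integer produced by `np.zeros(..., dtype=int)` (typically int64), and the shape values are placeholders that, per the filename pattern above, would be recovered from the filename.

# Illustrative read-back of the flat .bin written above; shape and dtype must be known up front.
import numpy as np

num_rows, B, num_to_retrieve = 1000, 32, 100   # placeholders; the real values are encoded in the filename
path = f"retrieved_indices_{num_rows}_{B}_{num_to_retrieve}_2.bin"
retrieved = np.memmap(path, dtype=np.int64, mode="r",
                      shape=(num_rows, B, num_to_retrieve, 2))
row_idx, step_idx = retrieved[0, 0, 0]         # (row, step) of the first neighbour of row 0, timestep 0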
replay.mp4
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b0216fee8d31be90d00225c2aa7422ba6dc889e245c70f8e0e627b32bdcce528
+size 1003098
sf_log.txt
ADDED
The diff for this file is too large to render. See raw diff