diff --git a/README.md b/README.md
index e51a12b..a6e1ca1 100644
--- a/README.md
+++ b/README.md
@@ -21,6 +21,21 @@ conda activate jat
 pip install -e .[dev]
 ```
 
+## REGENT fork of sample-factory: Installation
+Follow [this installation link](https://www.samplefactory.dev/01-get-started/installation/), but for the fork:
+```shell
+git clone https://github.com/kaustubhsridhar/sample-factory.git
+cd sample-factory
+pip install -e .[dev,mujoco,atari,envpool,vizdoom]
+```
+
+## REGENT fork of sample-factory: Train Unseen Env Policies and Generate Datasets
+Train policies for the unseen Atari environments using envpool:
+```shell
+bash scripts_sample-factory/train_unseen_atari.sh
+```
+Note that the training command inside the above script was obtained from the config files of Ed Beeching's Atari 57 models on Hugging Face. An example config is [here](https://huggingface.co/edbeeching/atari_2B_atari_mspacman_1111/blob/main/cfg.json#L124). See my discussion [here](https://huggingface.co/edbeeching/atari_2B_atari_mspacman_1111/discussions/2).
+
 ## PREV Installation
 
 To get started with JAT, follow these steps:
@@ -155,12 +170,21 @@ python -u scripts_jat_regent/eval_RandP.py --task ${TASK} &> outputs/RandP/${TAS
 ```
 
 ### REGENT Analyze data
+Necessary:
 ```shell
-python -u examples_regent/compare_datasets.py &> examples_regent/compare_datasets.txt &
-
 python -u examples_regent/analyze_rows_tokenized.py &> examples_regent/analyze_rows_tokenized.txt &
+```
 
+Already run (the resulting dicts are hardcoded in the code):
+```shell
 python -u examples_regent/get_dim_all_vector_tasks.py &> examples_regent/get_dim_all_vector_tasks.txt &
+
+python -u examples_regent/count_rows_to_consider.py &> examples_regent/count_rows_to_consider.txt &
+```
+
+Optional:
+```shell
+python -u examples_regent/compare_datasets.py &> examples_regent/compare_datasets.txt &
 ```
 
 ## PREV Dataset
diff --git a/jat_regent/RandP.py b/jat_regent/RandP.py
deleted file mode 100644
index b2bd8bf..0000000
--- a/jat_regent/RandP.py
+++ /dev/null
@@ -1,38 +0,0 @@
-import warnings
-from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
-
-import numpy as np
-import torch
-import torch.nn.functional as F
-from gymnasium import spaces
-from torch import BoolTensor, FloatTensor, LongTensor, Tensor, nn
-from transformers import GPTNeoModel, GPTNeoPreTrainedModel
-from transformers.modeling_outputs import ModelOutput
-from transformers.models.vit.modeling_vit import ViTPatchEmbeddings
-
-from jat.configuration_jat import JatConfig
-from jat.processing_jat import JatProcessor
-
-
-class RandP():
-    def __init__(self, dataset) -> None:
-        self.steps = 0
-        # create an index for retrieval in vector obs envs (OR) collect all images in Atari
-
-    def reset_rl(self):
-        self.steps = 0
-
-    def get_next_action(
-        self,
-        processor: JatProcessor,
-        continuous_observation: Optional[List[float]] = None,
-        discrete_observation: Optional[List[int]] = None,
-        text_observation: Optional[str] = None,
-        image_observation: Optional[np.ndarray] = None,
-        action_space: Union[spaces.Box, spaces.Discrete] = None,
-        reward: Optional[float] = None,
-        deterministic: bool = False,
-        context_window: Optional[int] = None,
-    ):
-        pass
\ No newline at end of file
diff --git a/jat_regent/modelling_jat_regent.py b/jat_regent/modelling_jat_regent.py
deleted file mode 100644
index e69de29..0000000
diff --git a/jat_regent/utils.py b/jat_regent/utils.py
index 56bfb44..36f6cca 100644
--- a/jat_regent/utils.py
+++ b/jat_regent/utils.py
@@ -8,23 +8,35 @@ from tqdm import tqdm
 from autofaiss import build_index
 
 
+UNSEEN_TASK_NAMES = { # Total -- atari: 57, metaworld: 50, babyai: 39, mujoco: 11
+    
+}
+
 def myprint(str):
-    # check if first character of string is a newline character
-    if str[0] == '\n':
-        str_without_newline = str[1:]
+    # print a blank line for each leading newline character, then strip them
+    num_newlines = 0
+    while str[num_newlines] == '\n':
         print()
-    else:
-        str_without_newline = str
+        num_newlines += 1
+    str_without_newline = str[num_newlines:]
     print(f'{datetime.now().strftime("%Y-%m-%d %H:%M:%S")}:     {str_without_newline}')
 
 def is_png_img(item):
     return isinstance(item, PngImagePlugin.PngImageFile)
 
+def get_last_row_for_1M_states(task):
+    last_row_idx = {'atari-alien': 14134, 'atari-amidar': 14319, 'atari-assault': 14427, 'atari-asterix': 14456, 'atari-asteroids': 14348, 'atari-atlantis': 14325, 'atari-bankheist': 14167, 'atari-battlezone': 13981, 'atari-beamrider': 13442, 'atari-berzerk': 13534, 'atari-bowling': 14110, 'atari-boxing': 14542, 'atari-breakout': 13474, 'atari-centipede': 14196, 'atari-choppercommand': 13397, 'atari-crazyclimber': 14026, 'atari-defender': 13504, 'atari-demonattack': 13499, 'atari-doubledunk': 14292, 'atari-enduro': 13260, 'atari-fishingderby': 14073, 'atari-freeway': 14016, 'atari-frostbite': 14075, 'atari-gopher': 13143, 'atari-gravitar': 14405, 'atari-hero': 14044, 'atari-icehockey': 14017, 'atari-jamesbond': 12678, 'atari-kangaroo': 14248, 'atari-krull': 14204, 'atari-kungfumaster': 14030, 'atari-montezumarevenge': 14219, 'atari-mspacman': 14120, 'atari-namethisgame': 13575, 'atari-phoenix': 13539, 'atari-pitfall': 14287, 'atari-pong': 14151, 'atari-privateeye': 14105, 'atari-qbert': 14026, 'atari-riverraid': 14275, 'atari-roadrunner': 14127, 'atari-robotank': 14079, 'atari-seaquest': 14097, 'atari-skiing': 14708, 'atari-solaris': 14199, 'atari-spaceinvaders': 12652, 'atari-stargunner': 13822, 'atari-surround': 13840, 'atari-tennis': 14062, 'atari-timepilot': 13896, 'atari-tutankham': 13121, 'atari-upndown': 13504, 'atari-venture': 14260, 'atari-videopinball': 14272, 'atari-wizardofwor': 13920, 'atari-yarsrevenge': 13981, 'atari-zaxxon': 13833, 'babyai-action-obj-door': 95000, 'babyai-blocked-unlock-pickup': 29279, 'babyai-boss-level-no-unlock': 12087, 'babyai-boss-level': 12101, 'babyai-find-obj-s5': 32974, 'babyai-go-to-door': 95000, 'babyai-go-to-imp-unlock': 9286, 'babyai-go-to-local': 95000, 'babyai-go-to-obj-door': 95000, 'babyai-go-to-obj': 95000, 'babyai-go-to-red-ball-grey': 95000, 'babyai-go-to-red-ball-no-dists': 95000, 'babyai-go-to-red-ball': 95000, 'babyai-go-to-red-blue-ball': 95000, 'babyai-go-to-seq': 13744, 'babyai-go-to': 18974, 'babyai-key-corridor': 9014, 'babyai-mini-boss-level': 38119, 'babyai-move-two-across-s8n9': 24505, 'babyai-one-room-s8': 95000, 'babyai-open-door': 95000, 'babyai-open-doors-order-n4': 95000, 'babyai-open-red-door': 95000, 'babyai-open-two-doors': 73291, 'babyai-open': 32559, 'babyai-pickup-above': 34084, 'babyai-pickup-dist': 89640, 'babyai-pickup-loc': 95000, 'babyai-pickup': 18670, 'babyai-put-next-local': 83187, 'babyai-put-next': 56986, 'babyai-synth-loc': 21605, 'babyai-synth-seq': 13049, 'babyai-synth': 19409, 'babyai-unblock-pickup': 17881, 'babyai-unlock-local': 71186, 'babyai-unlock-pickup': 50883, 'babyai-unlock-to-unlock': 23062, 'babyai-unlock': 11734, 'metaworld-assembly': 10000, 'metaworld-basketball': 10000, 'metaworld-bin-picking': 10000, 'metaworld-box-close': 10000, 'metaworld-button-press-topdown-wall': 10000, 'metaworld-button-press-topdown': 10000, 'metaworld-button-press-wall': 10000, 'metaworld-button-press': 10000, 'metaworld-coffee-button': 10000, 'metaworld-coffee-pull': 10000, 'metaworld-coffee-push': 10000, 'metaworld-dial-turn': 10000, 'metaworld-disassemble': 10000, 'metaworld-door-close': 10000, 'metaworld-door-lock': 10000, 'metaworld-door-open': 10000, 'metaworld-door-unlock': 10000, 'metaworld-drawer-close': 10000, 'metaworld-drawer-open': 10000, 'metaworld-faucet-close': 10000, 'metaworld-faucet-open': 10000, 'metaworld-hammer': 10000, 'metaworld-hand-insert': 10000, 'metaworld-handle-press-side': 10000, 'metaworld-handle-press': 10000, 'metaworld-handle-pull-side': 10000, 'metaworld-handle-pull': 10000, 
'metaworld-lever-pull': 10000, 'metaworld-peg-insert-side': 10000, 'metaworld-peg-unplug-side': 10000, 'metaworld-pick-out-of-hole': 10000, 'metaworld-pick-place-wall': 10000, 'metaworld-pick-place': 10000, 'metaworld-plate-slide-back-side': 10000, 'metaworld-plate-slide-back': 10000, 'metaworld-plate-slide-side': 10000, 'metaworld-plate-slide': 10000, 'metaworld-push-back': 10000, 'metaworld-push-wall': 10000, 'metaworld-push': 10000, 'metaworld-reach-wall': 10000, 'metaworld-reach': 10000, 'metaworld-shelf-place': 10000, 'metaworld-soccer': 10000, 'metaworld-stick-pull': 10000, 'metaworld-stick-push': 10000, 'metaworld-sweep-into': 10000, 'metaworld-sweep': 10000, 'metaworld-window-close': 10000, 'metaworld-window-open': 10000, 'mujoco-ant': 4023, 'mujoco-doublependulum': 4002, 'mujoco-halfcheetah': 4000, 'mujoco-hopper': 4931, 'mujoco-humanoid': 4119, 'mujoco-pendulum': 4959, 'mujoco-pusher': 9000, 'mujoco-reacher': 9000, 'mujoco-standup': 4000, 'mujoco-swimmer': 4000, 'mujoco-walker': 4101}
+    return last_row_idx[task]
+
+def get_last_row_for_100k_states(task):
+    last_row_idx = {'atari-alien': 3135, 'atari-amidar': 3142, 'atari-assault': 3132, 'atari-asterix': 3181, 'atari-asteroids': 3127, 'atari-atlantis': 3128, 'atari-bankheist': 3156, 'atari-battlezone': 3136, 'atari-beamrider': 3131, 'atari-berzerk': 3127, 'atari-bowling': 3148, 'atari-boxing': 3227, 'atari-breakout': 3128, 'atari-centipede': 3176, 'atari-choppercommand': 3144, 'atari-crazyclimber': 3134, 'atari-defender': 3127, 'atari-demonattack': 3127, 'atari-doubledunk': 3175, 'atari-enduro': 3126, 'atari-fishingderby': 3155, 'atari-freeway': 3131, 'atari-frostbite': 3146, 'atari-gopher': 3128, 'atari-gravitar': 3202, 'atari-hero': 3144, 'atari-icehockey': 3138, 'atari-jamesbond': 3131, 'atari-kangaroo': 3160, 'atari-krull': 3162, 'atari-kungfumaster': 3143, 'atari-montezumarevenge': 3168, 'atari-mspacman': 3143, 'atari-namethisgame': 3131, 'atari-phoenix': 3127, 'atari-pitfall': 3131, 'atari-pong': 3160, 'atari-privateeye': 3158, 'atari-qbert': 3136, 'atari-riverraid': 3157, 'atari-roadrunner': 3150, 'atari-robotank': 3133, 'atari-seaquest': 3138, 'atari-skiing': 3271, 'atari-solaris': 3129, 'atari-spaceinvaders': 3128, 'atari-stargunner': 3129, 'atari-surround': 3143, 'atari-tennis': 3129, 'atari-timepilot': 3132, 'atari-tutankham': 3127, 'atari-upndown': 3127, 'atari-venture': 3148, 'atari-videopinball': 3130, 'atari-wizardofwor': 3138, 'atari-yarsrevenge': 3129, 'atari-zaxxon': 3133, 'babyai-action-obj-door': 15923, 'babyai-blocked-unlock-pickup': 2919, 'babyai-boss-level-no-unlock': 1217, 'babyai-boss-level': 1159, 'babyai-find-obj-s5': 3345, 'babyai-go-to-door': 18875, 'babyai-go-to-imp-unlock': 923, 'babyai-go-to-local': 18724, 'babyai-go-to-obj-door': 16472, 'babyai-go-to-obj': 20197, 'babyai-go-to-red-ball-grey': 16953, 'babyai-go-to-red-ball-no-dists': 20165, 'babyai-go-to-red-ball': 18730, 'babyai-go-to-red-blue-ball': 16934, 'babyai-go-to-seq': 1439, 'babyai-go-to': 1964, 'babyai-key-corridor': 900, 'babyai-mini-boss-level': 3789, 'babyai-move-two-across-s8n9': 2462, 'babyai-one-room-s8': 16994, 'babyai-open-door': 13565, 'babyai-open-doors-order-n4': 9706, 'babyai-open-red-door': 21185, 'babyai-open-two-doors': 7348, 'babyai-open': 3331, 'babyai-pickup-above': 3392, 'babyai-pickup-dist': 19693, 'babyai-pickup-loc': 16405, 'babyai-pickup': 1806, 'babyai-put-next-local': 8303, 'babyai-put-next': 5703, 'babyai-synth-loc': 2183, 'babyai-synth-seq': 1316, 'babyai-synth': 1964, 'babyai-unblock-pickup': 1886, 'babyai-unlock-local': 7118, 'babyai-unlock-pickup': 5107, 'babyai-unlock-to-unlock': 2309, 'babyai-unlock': 1177, 'metaworld-assembly': 1000, 'metaworld-basketball': 1000, 'metaworld-bin-picking': 1000, 'metaworld-box-close': 1000, 'metaworld-button-press-topdown-wall': 1000, 'metaworld-button-press-topdown': 1000, 'metaworld-button-press-wall': 1000, 'metaworld-button-press': 1000, 'metaworld-coffee-button': 1000, 'metaworld-coffee-pull': 1000, 'metaworld-coffee-push': 1000, 'metaworld-dial-turn': 1000, 'metaworld-disassemble': 1000, 'metaworld-door-close': 1000, 'metaworld-door-lock': 1000, 'metaworld-door-open': 1000, 'metaworld-door-unlock': 1000, 'metaworld-drawer-close': 1000, 'metaworld-drawer-open': 1000, 'metaworld-faucet-close': 1000, 'metaworld-faucet-open': 1000, 'metaworld-hammer': 1000, 'metaworld-hand-insert': 1000, 'metaworld-handle-press-side': 1000, 'metaworld-handle-press': 1000, 'metaworld-handle-pull-side': 1000, 'metaworld-handle-pull': 1000, 'metaworld-lever-pull': 1000, 'metaworld-peg-insert-side': 1000, 'metaworld-peg-unplug-side': 1000, 
'metaworld-pick-out-of-hole': 1000, 'metaworld-pick-place-wall': 1000, 'metaworld-pick-place': 1000, 'metaworld-plate-slide-back-side': 1000, 'metaworld-plate-slide-back': 1000, 'metaworld-plate-slide-side': 1000, 'metaworld-plate-slide': 1000, 'metaworld-push-back': 1000, 'metaworld-push-wall': 1000, 'metaworld-push': 1000, 'metaworld-reach-wall': 1000, 'metaworld-reach': 1000, 'metaworld-shelf-place': 1000, 'metaworld-soccer': 1000, 'metaworld-stick-pull': 1000, 'metaworld-stick-push': 1000, 'metaworld-sweep-into': 1000, 'metaworld-sweep': 1000, 'metaworld-window-close': 1000, 'metaworld-window-open': 1000, 'mujoco-ant': 401, 'mujoco-doublependulum': 401, 'mujoco-halfcheetah': 400, 'mujoco-hopper': 491, 'mujoco-humanoid': 415, 'mujoco-pendulum': 495, 'mujoco-pusher': 1000, 'mujoco-reacher': 2000, 'mujoco-standup': 400, 'mujoco-swimmer': 400, 'mujoco-walker': 407}
+    return last_row_idx[task]
+
 def get_obs_dim(task):
     assert task.startswith("babyai") or task.startswith("metaworld") or task.startswith("mujoco")
 
     all_obs_dims={'babyai-action-obj-door': 212, 'babyai-blocked-unlock-pickup': 212, 'babyai-boss-level-no-unlock': 212, 'babyai-boss-level': 212, 'babyai-find-obj-s5': 212, 'babyai-go-to-door': 212, 'babyai-go-to-imp-unlock': 212, 'babyai-go-to-local': 212, 'babyai-go-to-obj-door': 212, 'babyai-go-to-obj': 212, 'babyai-go-to-red-ball-grey': 212, 'babyai-go-to-red-ball-no-dists': 212, 'babyai-go-to-red-ball': 212, 'babyai-go-to-red-blue-ball': 212, 'babyai-go-to-seq': 212, 'babyai-go-to': 212, 'babyai-key-corridor': 212, 'babyai-mini-boss-level': 212, 'babyai-move-two-across-s8n9': 212, 'babyai-one-room-s8': 212, 'babyai-open-door': 212, 'babyai-open-doors-order-n4': 212, 'babyai-open-red-door': 212, 'babyai-open-two-doors': 212, 'babyai-open': 212, 'babyai-pickup-above': 212, 'babyai-pickup-dist': 212, 'babyai-pickup-loc': 212, 'babyai-pickup': 212, 'babyai-put-next-local': 212, 'babyai-put-next': 212, 'babyai-synth-loc': 212, 'babyai-synth-seq': 212, 'babyai-synth': 212, 'babyai-unblock-pickup': 212, 'babyai-unlock-local': 212, 'babyai-unlock-pickup': 212, 'babyai-unlock-to-unlock': 212, 'babyai-unlock': 212, 'metaworld-assembly': 39, 'metaworld-basketball': 39, 'metaworld-bin-picking': 39, 'metaworld-box-close': 39, 'metaworld-button-press-topdown-wall': 39, 'metaworld-button-press-topdown': 39, 'metaworld-button-press-wall': 39, 'metaworld-button-press': 39, 'metaworld-coffee-button': 39, 'metaworld-coffee-pull': 39, 'metaworld-coffee-push': 39, 'metaworld-dial-turn': 39, 'metaworld-disassemble': 39, 'metaworld-door-close': 39, 'metaworld-door-lock': 39, 'metaworld-door-open': 39, 'metaworld-door-unlock': 39, 'metaworld-drawer-close': 39, 'metaworld-drawer-open': 39, 'metaworld-faucet-close': 39, 'metaworld-faucet-open': 39, 'metaworld-hammer': 39, 'metaworld-hand-insert': 39, 'metaworld-handle-press-side': 39, 'metaworld-handle-press': 39, 'metaworld-handle-pull-side': 39, 'metaworld-handle-pull': 39, 'metaworld-lever-pull': 39, 'metaworld-peg-insert-side': 39, 'metaworld-peg-unplug-side': 39, 'metaworld-pick-out-of-hole': 39, 'metaworld-pick-place-wall': 39, 'metaworld-pick-place': 39, 'metaworld-plate-slide-back-side': 39, 'metaworld-plate-slide-back': 39, 'metaworld-plate-slide-side': 39, 'metaworld-plate-slide': 39, 'metaworld-push-back': 39, 'metaworld-push-wall': 39, 'metaworld-push': 39, 'metaworld-reach-wall': 39, 'metaworld-reach': 39, 'metaworld-shelf-place': 39, 'metaworld-soccer': 39, 'metaworld-stick-pull': 39, 'metaworld-stick-push': 39, 'metaworld-sweep-into': 39, 'metaworld-sweep': 39, 'metaworld-window-close': 39, 'metaworld-window-open': 39, 'mujoco-ant': 27, 'mujoco-doublependulum': 11, 'mujoco-halfcheetah': 17, 'mujoco-hopper': 11, 'mujoco-humanoid': 376, 'mujoco-pendulum': 4, 'mujoco-pusher': 23, 'mujoco-reacher': 11, 'mujoco-standup': 376, 'mujoco-swimmer': 8, 'mujoco-walker': 17}
-    return all_obs_dims[task]
+    return (all_obs_dims[task],)
 
 def get_act_dim(task):
     assert task.startswith("babyai") or task.startswith("metaworld") or task.startswith("mujoco")
@@ -36,141 +48,188 @@ def get_act_dim(task):
     elif task.startswith("mujoco"):
         all_act_dims={'mujoco-ant': 8, 'mujoco-doublependulum': 1, 'mujoco-halfcheetah': 6, 'mujoco-hopper': 3, 'mujoco-humanoid': 17, 'mujoco-pendulum': 1, 'mujoco-pusher': 7, 'mujoco-reacher': 2, 'mujoco-standup': 17, 'mujoco-swimmer': 2, 'mujoco-walker': 6}
         return all_act_dims[task]
-
-def process_row_atari(attn_mask, row_of_obs, task):
-    """
-    Example for selection with bools:
-    >>> a = np.array([0,1,2,3,4,5])
-    >>> b = np.array([1,0,0,0,0,1]).astype(bool)
-    >>> a[b]
-    array([0, 5])
-    """
-    attn_mask = np.array(attn_mask).astype(bool)
     
-    row_of_obs = torch.stack([to_tensor(np.array(img)) for img in row_of_obs])
-    row_of_obs = row_of_obs[attn_mask]
+def get_task_info(task):
+    rew_key = 'rewards'
+    attn_key = 'attention_mask'
+    if task.startswith("atari"):
+        obs_key = 'image_observations'
+        act_key = 'discrete_actions'
+        B = 32 # half of 54
+        obs_dim = (3, 4*84, 84)
+    elif task.startswith("babyai"):
+        obs_key = 'discrete_observations' # the raw dataset also has 'text_observations', but the tokenized dataset folds them into 'discrete_observations'
+        act_key = 'discrete_actions'
+        B = 256 # half of 512
+        obs_dim = get_obs_dim(task)
+    elif task.startswith("metaworld") or task.startswith("mujoco"):
+        obs_key = 'continuous_observations'
+        act_key = 'continuous_actions'
+        B = 256
+        obs_dim = get_obs_dim(task)
+
+    return rew_key, attn_key, obs_key, act_key, B, obs_dim
+
+def process_row_of_obs_atari_full_without_mask(row_of_obs):
+
+    if not isinstance(row_of_obs, torch.Tensor):
+        row_of_obs = torch.stack([to_tensor(np.array(img)) for img in row_of_obs])  
     row_of_obs = row_of_obs * 0.5 + 0.5 # denormalize from [-1, 1] to [0, 1]
-    assert row_of_obs.shape == (sum(attn_mask), 84, 4, 84)
+    assert row_of_obs.shape == (len(row_of_obs), 84, 4, 84)
     row_of_obs = row_of_obs.permute(0, 2, 1, 3) # (*, 4, 84, 84)
-    row_of_obs = row_of_obs.reshape(sum(attn_mask), 4*84, 84) # put side-by-side
+    row_of_obs = row_of_obs.reshape(len(row_of_obs), 4*84, 84) # put side-by-side
     row_of_obs = row_of_obs.unsqueeze(1).repeat(1, 3, 1, 1) # repeat for 3 channels
-    assert row_of_obs.shape == (sum(attn_mask), 3, 4*84, 84) # sum(attn_mask) is the batch size dimension
-    
-    return attn_mask, row_of_obs
+    assert row_of_obs.shape == (len(row_of_obs), 3, 4*84, 84) # sum(attn_mask) is the batch size dimension
+
+    return row_of_obs
 
-def process_row_vector(attn_mask, row_of_obs, task, return_numpy=False):
-    attn_mask = np.array(attn_mask).astype(bool)
+def collect_all_atari_data(dataset, all_row_idxs=None):
+    if all_row_idxs is None:
+        all_row_idxs = list(range(len(dataset['train'])))
     
-    row_of_obs = np.array(row_of_obs)
-    if not return_numpy:
-        row_of_obs = torch.tensor(row_of_obs)
-    row_of_obs = row_of_obs[attn_mask]
-    assert row_of_obs.shape == (sum(attn_mask), get_obs_dim(task))
-
-    return attn_mask, row_of_obs
-
-def retrieve_atari(row_of_obs, # query: (row_B, 3, 4*84, 84)
-            dataset, # to retrieve from
-            all_rows_to_consider, # rows to consider
-            num_to_retrieve, # top-k
+    all_rows_of_obs = []
+    all_attn_masks = []
+    for row_idx in tqdm(all_row_idxs):
+        datarow = dataset['train'][row_idx]
+        row_of_obs = process_row_of_obs_atari_full_without_mask(datarow['image_observations'])
+        attn_mask = np.array(datarow['attention_mask']).astype(bool)
+        all_rows_of_obs.append(row_of_obs) # appending tensor
+        all_attn_masks.append(attn_mask) # appending np array
+    all_rows_of_obs = torch.stack(all_rows_of_obs, dim=0) # stacking tensors
+    all_attn_masks = np.stack(all_attn_masks, axis=0) # stacking np arrays
+    assert (all_rows_of_obs.shape == (len(all_row_idxs), 32, 3, 4*84, 84) and
+            all_attn_masks.shape == (len(all_row_idxs), 32))
+    return all_attn_masks, all_rows_of_obs
+
+def collect_all_data(dataset, task, obs_key):
+    last_row_idx = get_last_row_for_100k_states(task)
+    all_row_idxs = list(range(last_row_idx))
+    if task.startswith("atari"):
+        myprint("Collecting all Atari images and Atari attention masks...")
+        all_attn_masks_OG, all_rows_of_obs_OG = collect_all_atari_data(dataset, all_row_idxs)
+    else:
+        datarows = dataset['train'][all_row_idxs]
+        all_rows_of_obs_OG = np.array(datarows[obs_key])
+        all_attn_masks_OG = np.array(datarows['attention_mask']).astype(bool)
+    return all_rows_of_obs_OG, all_attn_masks_OG, all_row_idxs
+
+def collect_subset(all_rows_of_obs_OG, 
+                   all_attn_masks_OG, 
+                   all_rows_to_consider, 
+                   kwargs
+    ):
+    """
+    Function to collect subset of data given all_rows_to_consider, reshape it, create all_indices and return.
+    Used in both retrieve_atari() and retrieve_vector() --> build_index_vector().
+    """
+    myprint(f'\n\n\n' + ('-'*100) + f'Collecting subset...')
+    # read kwargs
+    B, task, obs_dim = kwargs['B'], kwargs['task'], kwargs['obs_dim']
+
+    # take subset based on all_rows_to_consider
+    myprint(f'Taking subset of data based on all_rows_to_consider...')
+    all_processed_rows_of_obs = all_rows_of_obs_OG[all_rows_to_consider]
+    all_attn_masks = all_attn_masks_OG[all_rows_to_consider]
+    assert (all_processed_rows_of_obs.shape == (len(all_rows_to_consider), B, *obs_dim) and
+            all_attn_masks.shape == (len(all_rows_to_consider), B))
+
+    # reshape
+    myprint(f'Reshaping data...')
+    all_attn_masks = all_attn_masks.reshape(-1)
+    all_processed_rows_of_obs = all_processed_rows_of_obs.reshape(-1, *obs_dim)
+    all_processed_rows_of_obs = all_processed_rows_of_obs[all_attn_masks]
+    assert (all_attn_masks.shape == (len(all_rows_to_consider) * B,) and 
+            all_processed_rows_of_obs.shape == (np.sum(all_attn_masks), *obs_dim))
+
+    # collect indices of data
+    myprint(f'Collecting indices of data...')
+    all_indices = np.array([[row_idx, i] for row_idx in all_rows_to_consider for i in range(B)])
+    all_indices = all_indices[all_attn_masks] # this is fine because all attn masks have 0s that only come after 1s
+    assert all_indices.shape == (np.sum(all_attn_masks), 2)
+
+    myprint(f'{all_indices.shape=}, {all_processed_rows_of_obs.shape=}')
+    myprint(('-'*100) + '\n\n\n')
+    return all_indices, all_processed_rows_of_obs
+
+def retrieve_atari(row_of_obs, # query: (xbdim, 3, 4*84, 84), i.e., (xbdim, *obs_dim)
+            all_processed_rows_of_obs,
+            all_indices,
+            num_to_retrieve,
             kwargs
-            ):
+    ):
+    """
+    Retrieval for Atari with images, ssim distance, and on GPU.
+    """
     assert isinstance(row_of_obs, torch.Tensor)
 
     # read kwargs # Note: B = len of row
-    B, attn_key, obs_key, device, task, batch_size_retrieval = kwargs['B'], kwargs['attn_key'], kwargs['obs_key'], kwargs['device'], kwargs['task'], kwargs['batch_size_retrieval']
+    B, device, batch_size_retrieval = kwargs['B'], kwargs['device'], kwargs['batch_size_retrieval']
     
     # batch size of row_of_obs which can be <= B since we process before calling this function
-    row_B = row_of_obs.shape[0]
-    
+    xbdim = row_of_obs.shape[0]
+
+    # number of states in the data we can retrieve from
+    ydim = all_processed_rows_of_obs.shape[0]
+
     # first argument for ssim
-    repeated_row_og = row_of_obs.repeat_interleave(B, dim=0).to(device)
-    assert repeated_row_og.shape == (row_B*B, 3, 4*84, 84)
+    xbatch = row_of_obs.repeat_interleave(batch_size_retrieval, dim=0).to(device)
+    assert xbatch.shape == (xbdim * batch_size_retrieval, 3, 4*84, 84)
 
-    # iterate over all other rows
+    # iterate over data that we can retrieve from in batches
     all_ssim = []
-    all_indices = []
-    total = 0
-    for other_row_idx in tqdm(all_rows_to_consider):
-        other_attn_mask, other_row_of_obs = process_row_atari(dataset['train'][other_row_idx][attn_key], dataset['train'][other_row_idx][obs_key])
-        
-        # batch size of other_row_of_obs
-        other_row_B = other_row_of_obs.shape[0]
-        total += other_row_B
-
-        # first argument for ssim: RECHECK
-        if other_row_B < B: # when other row has less observations than expected
-            repeated_row = row_of_obs.repeat_interleave(other_row_B, dim=0).to(device)
-        elif other_row_B == B: # otherwise just use the one created before the for loop
-            repeated_row = repeated_row_og
-        assert repeated_row.shape == (row_B*other_row_B, 3, 4*84, 84)
-
+    for j in range(0, ydim, batch_size_retrieval):
         # second argument for ssim
-        repeated_other_row = other_row_of_obs.repeat(row_B, 1, 1, 1).to(device)
-        assert repeated_other_row.shape == (row_B*other_row_B, 3, 4*84, 84)
+        ybatch = all_processed_rows_of_obs[j:j+batch_size_retrieval]
+        ybdim = ybatch.shape[0]
+        ybatch = ybatch.repeat(xbdim, 1, 1, 1).to(device)
+        assert ybatch.shape == (ybdim * xbdim, 3, 4*84, 84)
+
+        if ybdim < batch_size_retrieval: # for last batch
+            xbatch = row_of_obs.repeat_interleave(ybdim, dim=0).to(device)
+        assert xbatch.shape == (xbdim * ybdim, 3, 4*84, 84)
 
         # compare via ssim and updated all_ssim
-        ssim_score = ssim(repeated_row, repeated_other_row, data_range=1.0, size_average=False)
-        ssim_score = ssim_score.reshape(row_B, other_row_B)
+        ssim_score = ssim(xbatch, ybatch, data_range=1.0, size_average=False)
+        ssim_score = ssim_score.reshape(xbdim, ybdim)
         all_ssim.append(ssim_score)
 
-        # update all_indices
-        all_indices.extend([[other_row_idx, i] for i in range(other_row_B)])
-
     # concat
     all_ssim = torch.cat(all_ssim, dim=1)
-    assert all_ssim.shape == (row_B, total)
+    assert all_ssim.shape == (xbdim, ydim)
 
-    all_indices = np.array(all_indices)
-    assert all_indices.shape == (total, 2)
+    assert all_indices.shape == (ydim, 2)
 
     # get top-k indices
     topk_values, topk_indices = torch.topk(all_ssim, num_to_retrieve, dim=1, largest=True)
     topk_indices = topk_indices.cpu().numpy()
-    assert topk_indices.shape == (row_B, num_to_retrieve)
+    assert topk_indices.shape == (xbdim, num_to_retrieve)
 
     # convert topk indices to indices in the dataset
-    retrieved_indices = np.array(all_indices[topk_indices]) 
-    assert retrieved_indices.shape == (row_B, num_to_retrieve, 2)
-
-    # pad the above to expected B
-    if row_B < B:
-        retrieved_indices = np.concatenate([retrieved_indices, np.zeros((B-row_B, num_to_retrieve, 2), dtype=int)], axis=0)
-    assert retrieved_indices.shape == (B, num_to_retrieve, 2)
+    retrieved_indices = all_indices[topk_indices]
+    assert retrieved_indices.shape == (xbdim, num_to_retrieve, 2)
 
     return retrieved_indices
 
-def build_index_vector(all_rows_of_obs_og, 
-                       all_attn_masks_og, 
+def build_index_vector(all_rows_of_obs_OG, 
+                       all_attn_masks_OG, 
                        all_rows_to_consider, 
                        kwargs
-                    ):
+    ):
+    """
+    Builds FAISS index for vector observation environments.
+    """
     # read kwargs # Note: B = len of row
-    B, attn_key, obs_key, device, task, batch_size_retrieval, nb_cores_autofaiss = kwargs['B'], kwargs['attn_key'], kwargs['obs_key'], kwargs['device'], kwargs['task'], kwargs['batch_size_retrieval'], kwargs['nb_cores_autofaiss']
-    obs_dim = get_obs_dim(task)
+    nb_cores_autofaiss = kwargs['nb_cores_autofaiss']
 
-    # take subset based on all_rows_to_consider
-    myprint(f'Taking subset')
-    all_rows_of_obs = all_rows_of_obs_og[all_rows_to_consider]
-    all_attn_masks = all_attn_masks_og[all_rows_to_consider]
-    assert (all_rows_of_obs.shape == (len(all_rows_to_consider), B, obs_dim) and
-            all_attn_masks.shape == (len(all_rows_to_consider), B))
-    
-    # reshape
-    all_attn_masks = all_attn_masks.reshape(-1)
-    all_rows_of_obs = all_rows_of_obs.reshape(-1, obs_dim)
-    all_rows_of_obs = all_rows_of_obs[all_attn_masks]
-    assert all_rows_of_obs.shape == (np.sum(all_attn_masks), obs_dim)
+    # take subset based on all_rows_to_consider, reshape, and save indices of data
+    all_indices, all_processed_rows_of_obs = collect_subset(all_rows_of_obs_OG, all_attn_masks_OG, all_rows_to_consider, kwargs)
 
-    # save indices of data to retrieve from
-    myprint(f'Saving indices of data to retrieve from')
-    all_indices = np.array([[row_idx, i] for row_idx in all_rows_to_consider for i in range(B)])
-    all_indices = all_indices[all_attn_masks] # this is fine because all attn masks have 0s that only come after 1s
-    assert all_indices.shape == (np.sum(all_attn_masks), 2)
+    # make sure input to build_index is float, otherwise you will get reading temp file error
+    all_processed_rows_of_obs = all_processed_rows_of_obs.astype(float)
 
     # build index
-    myprint(f'Building index...')
-    knn_index, knn_index_infos = build_index(embeddings=all_rows_of_obs, # Note: embeddings have to be float to avoid errors in autofaiss / embedding_reader!
+    myprint(('-'*100) + 'Building index...')
+    knn_index, knn_index_infos = build_index(embeddings=all_processed_rows_of_obs, # Note: embeddings have to be float to avoid errors in autofaiss / embedding_reader!
                                             save_on_disk=False,
                                             min_nearest_neighbors_to_retrieve=20, # default: 20
                                             max_index_query_time_ms=10, # default: 10
@@ -179,34 +238,32 @@ def build_index_vector(all_rows_of_obs_og,
                                             metric_type='l2',
                                             nb_cores=nb_cores_autofaiss, # default: None # "The number of cores to use, by default will use all cores" as seen in https://criteo.github.io/autofaiss/getting_started/quantization.html#the-build-index-command
                                             )
+    myprint(('-'*100) + '\n\n\n')
     
-    return knn_index, all_indices
+    return all_indices, knn_index
 
-def retrieve_vector(row_of_obs, # query: (row_B, dim)
-            dataset, # to retrieve from
-            all_rows_to_consider, # rows to consider
-            num_to_retrieve, # top-k
+def retrieve_vector(row_of_obs, # query: (xbdim, *obs_dim)
+            knn_index,
+            all_indices,
+            num_to_retrieve,
             kwargs
-            ):
+    ):
+    """
+    Retrieval for vector observation environments.
+    """
     assert isinstance(row_of_obs, np.ndarray)
 
     # read few kwargs
     B = kwargs['B']
 
     # batch size of row_of_obs which can be <= B since we process before calling this function
-    row_B = row_of_obs.shape[0]
+    xbdim = row_of_obs.shape[0]
 
-    # read dataset_tuple
-    all_rows_of_obs, all_attn_masks = dataset
-
-    # create index and all_indices
-    knn_index, all_indices = build_index_vector(all_rows_of_obs, all_attn_masks, all_rows_to_consider, kwargs)
-    
     # retrieve
     myprint(f'Retrieving...')
     topk_indices, _ = knn_index.search(row_of_obs, 10 * num_to_retrieve)
     topk_indices = topk_indices.astype(int)
-    assert topk_indices.shape == (row_B, 10 * num_to_retrieve)
+    assert topk_indices.shape == (xbdim, 10 * num_to_retrieve)
 
     # remove -1s and crop to num_to_retrieve
     try:
@@ -219,16 +276,10 @@ def retrieve_vector(row_of_obs, # query: (row_B, dim)
         print(f'-------------------------------------------------------------------------------------------------------------------------------------------')
         print(f'Leaving some -1s in topk_indices and continuing')
         topk_indices = np.array([indices[:num_to_retrieve] for indices in topk_indices])
-    assert topk_indices.shape == (row_B, num_to_retrieve)
+    assert topk_indices.shape == (xbdim, num_to_retrieve)
 
     # convert topk indices to indices in the dataset
     retrieved_indices = all_indices[topk_indices]
-    assert retrieved_indices.shape == (row_B, num_to_retrieve, 2)
-
-    # pad the above to expected B
-    if row_B < B:
-        retrieved_indices = np.concatenate([retrieved_indices, np.zeros((B-row_B, num_to_retrieve, 2), dtype=int)], axis=0)
-    assert retrieved_indices.shape == (B, num_to_retrieve, 2)
+    assert retrieved_indices.shape == (xbdim, num_to_retrieve, 2)
 
-    myprint(f'Returning')
     return retrieved_indices
\ No newline at end of file
diff --git a/scripts_regent/eval_RandP.py b/scripts_regent/eval_RandP.py
index 07e545c..146b347 100755
--- a/scripts_regent/eval_RandP.py
+++ b/scripts_regent/eval_RandP.py
@@ -15,9 +15,10 @@ from transformers import AutoModelForCausalLM, AutoProcessor, HfArgumentParser
 
 from jat.eval.rl import TASK_NAME_TO_ENV_ID, make
 from jat.utils import normalize, push_to_hub, save_video_grid
-from jat_regent.RandP import RandP
+from jat_regent.modeling_RandP import RandP
 from datasets import load_from_disk
 from datasets.config import HF_DATASETS_CACHE
+from jat_regent.utils import myprint
 
 
 @dataclass
@@ -70,6 +71,7 @@ def eval_rl(model, processor, task, eval_args):
     scores = []
     frames = []
     for episode in tqdm(range(eval_args.num_episodes), desc=task, unit="episode", leave=False):
+        myprint(('-'*100) + f'{episode=}')
         observation, _ = env.reset()
         reward = None
         rewards = []
@@ -96,6 +98,7 @@ def eval_rl(model, processor, task, eval_args):
                 frames.append(np.array(env.render(), dtype=np.uint8))
 
         scores.append(sum(rewards))
+        myprint(('-'*100) + '\n\n\n')
     env.close()
 
     raw_mean, raw_std = np.mean(scores), np.std(scores)
@@ -145,7 +148,9 @@ def main():
             tasks.extend([env_id for env_id in TASK_NAME_TO_ENV_ID.keys() if env_id.startswith(domain)])
 
     device = torch.device("cpu") if eval_args.use_cpu else get_default_device()
-    processor = None
+    processor = AutoProcessor.from_pretrained(
+        'jat-project/jat', cache_dir=None, trust_remote_code=True
+    )
 
     evaluations = {}
     video_list = []
@@ -153,14 +158,18 @@ def main():
 
     for task in tqdm(tasks, desc="Evaluation", unit="task", leave=True):
         if task in TASK_NAME_TO_ENV_ID.keys():
+            myprint(('-'*100) + f'{task=}')
             dataset = load_from_disk(f'{HF_DATASETS_CACHE}/jat-project/jat-dataset-tokenized/{task}')
-            model = RandP(dataset)
+            model = RandP(task, 
+                          dataset, 
+                          device,)
             scores, frames, fps = eval_rl(model, processor, task, eval_args)
             evaluations[task] = scores
             # Save the video
             if eval_args.save_video:
                 video_list.append(frames)
                 input_fps.append(fps)
+            myprint(('-'*100) + '\n\n\n')
         else:
             warnings.warn(f"Task {task} is not supported.")
 
diff --git a/scripts_regent/offline_retrieval_jat_regent.py b/scripts_regent/offline_retrieval_jat_regent.py
index c83d259..aad678a 100644
--- a/scripts_regent/offline_retrieval_jat_regent.py
+++ b/scripts_regent/offline_retrieval_jat_regent.py
@@ -8,7 +8,7 @@ import time
 from datetime import datetime
 from datasets import load_from_disk
 from datasets.config import HF_DATASETS_CACHE
-from jat_regent.utils import myprint, process_row_atari, process_row_vector, retrieve_atari, retrieve_vector
+from jat_regent.utils import myprint, get_task_info, collect_all_data, process_row_of_obs_atari_full_without_mask, retrieve_atari, retrieve_vector, collect_subset, build_index_vector
 import logging
 logging.basicConfig(level=logging.DEBUG)
 
@@ -17,7 +17,8 @@ def main():
     parser = argparse.ArgumentParser(description='Build RAAGENT sequence indices')
     parser.add_argument('--task', type=str, default='atari-alien', help='Task name')
     parser.add_argument('--num_to_retrieve', type=int, default=100, help='Number of states/windows to retrieve')
-    parser.add_argument('--nb_cores_autofaiss', type=int, default=16, help='Number of cores to use for faiss in vector observation environments')
+    parser.add_argument('--nb_cores_autofaiss', type=int, default=16, help='Number of cores to use for faiss in vector obs envs')
+    parser.add_argument('--batch_size_retrieval', type=int, default=1024, help='Batch size for retrieval in atari')
     args = parser.parse_args()
 
     # load dataset, map, device, for task
@@ -25,77 +26,83 @@ def main():
     dataset_path = f"{HF_DATASETS_CACHE}/jat-project/jat-dataset-tokenized/{task}"
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-    rew_key = 'rewards'
-    attn_key = 'attention_mask'
-    if task.startswith("atari"):
-        obs_key = 'image_observations'
-        act_key = 'discrete_actions'
-        len_row_tokenized_known = 32 # half of 54
-        process_row_fn = process_row_atari
-        retrieve_fn = retrieve_atari
-    elif task.startswith("babyai"):
-        obs_key = 'discrete_observations'# also has 'text_observations' only for raw dataset not for tokenized dataset (as it is combined into discrete_observation in tokenized dataset)
-        act_key = 'discrete_actions'
-        len_row_tokenized_known = 256 # half of 512
-        process_row_fn = lambda attn_mask, row_of_obs, task: process_row_vector(attn_mask, row_of_obs, task, return_numpy=True)
-        retrieve_fn = retrieve_vector
-    elif task.startswith("metaworld") or task.startswith("mujoco"):
-        obs_key = 'continuous_observations'
-        act_key = 'continuous_actions'
-        len_row_tokenized_known = 256
-        process_row_fn = lambda attn_mask, row_of_obs, task: process_row_vector(attn_mask, row_of_obs, task, return_numpy=True)
-        retrieve_fn = retrieve_vector
+    rew_key, attn_key, obs_key, act_key, B, obs_dim = get_task_info(task)
 
     dataset = load_from_disk(dataset_path)
     with open(f"{dataset_path}/map_from_rows_to_episodes_for_tokenized.json", 'r') as f:
         map_from_rows_to_episodes_for_tokenized = json.load(f)
 
     # setup kwargs
-    len_dataset = len(dataset['train'])
-    B = len_row_tokenized_known
     kwargs = {'B': B, 
-            'attn_key':attn_key, 
-            'obs_key':obs_key, 
-            'device':device,
-            'task':task,
-            'batch_size_retrieval':None,
-            'nb_cores_autofaiss':None if task.startswith("atari") else args.nb_cores_autofaiss,
-        }
+              'obs_dim': obs_dim,
+              'attn_key': attn_key, 
+              'obs_key': obs_key, 
+              'device': device,
+              'task': task,
+              'batch_size_retrieval': args.batch_size_retrieval,
+              'nb_cores_autofaiss': None if task.startswith("atari") else args.nb_cores_autofaiss,
+            }
 
     # collect all observations in a single array (this takes some time) for vector observation environments
-    if not task.startswith("atari"):
-        myprint("Collecting all observations/attn_masks in a single array")
-        all_rows_of_obs = np.array(dataset['train'][obs_key])
-        all_attn_masks = np.array(dataset['train'][attn_key]).astype(bool)
+    myprint("Collecting all observations/attn_masks")
+    all_rows_of_obs_OG, all_attn_masks_OG, all_row_idxs = collect_all_data(dataset, task, obs_key)
 
     # iterate over rows
     all_retrieved_indices = []
-    for row_idx in range(len_dataset):
-        myprint(f"\nProcessing row {row_idx}/{len_dataset}")
+    for row_idx in all_row_idxs:
+        myprint(f"\nProcessing row {row_idx}/{len(all_row_idxs)}")
         current_ep = map_from_rows_to_episodes_for_tokenized[str(row_idx)]
 
-        attn_mask, row_of_obs = process_row_fn(dataset['train'][row_idx][attn_key], dataset['train'][row_idx][obs_key], task)
+        # get row_of_obs and attn_mask
+        datarow = dataset['train'][row_idx]
+        attn_mask = np.array(datarow[attn_key]).astype(bool)
+        if task.startswith("atari"):
+            row_of_obs = process_row_of_obs_atari_full_without_mask(datarow[obs_key])
+        else:
+            row_of_obs = np.array(datarow[obs_key])
+        row_of_obs = row_of_obs[attn_mask]
+        assert row_of_obs.shape == (np.sum(attn_mask), *obs_dim)
         
         # compare with rows from all but the current episode
-        all_other_rows = [idx for idx in range(len_dataset) if map_from_rows_to_episodes_for_tokenized[str(idx)] != current_ep]
+        all_other_row_idxs = [idx for idx in all_row_idxs if map_from_rows_to_episodes_for_tokenized[str(idx)] != current_ep]
 
         # do the retrieval
-        retrieved_indices = retrieve_fn(row_of_obs=row_of_obs, 
-                                        dataset=dataset if task.startswith("atari") else (all_rows_of_obs, all_attn_masks),
-                                        all_rows_to_consider=all_other_rows, 
-                                        num_to_retrieve=args.num_to_retrieve, 
-                                        kwargs=kwargs,
-                                        )
+        if task.startswith("atari"):
+            all_indices, all_processed_rows_of_obs = collect_subset(all_rows_of_obs_OG=all_rows_of_obs_OG,
+                                                                    all_attn_masks_OG=all_attn_masks_OG,
+                                                                    all_rows_to_consider=all_other_row_idxs,
+                                                                    kwargs=kwargs)
+            retrieved_indices = retrieve_atari(row_of_obs=row_of_obs,
+                                               all_processed_rows_of_obs=all_processed_rows_of_obs,
+                                               all_indices=all_indices,
+                                               num_to_retrieve=args.num_to_retrieve,
+                                               kwargs=kwargs)
+        else:
+            all_indices, knn_index = build_index_vector(all_rows_of_obs_OG=all_rows_of_obs_OG, 
+                                                        all_attn_masks_OG=all_attn_masks_OG, 
+                                                        all_rows_to_consider=all_other_row_idxs,
+                                                        kwargs=kwargs)
+            retrieved_indices = retrieve_vector(row_of_obs=row_of_obs, 
+                                                knn_index=knn_index, 
+                                                all_indices=all_indices, 
+                                                num_to_retrieve=args.num_to_retrieve,
+                                                kwargs=kwargs)
+            
+        # pad the above to expected B
+        xbdim = row_of_obs.shape[0]
+        if xbdim < B:
+            retrieved_indices = np.concatenate([retrieved_indices, np.zeros((B-xbdim, args.num_to_retrieve, 2), dtype=int)], axis=0)
+        assert retrieved_indices.shape == (B, args.num_to_retrieve, 2)
 
         # collect retrieved indices
         all_retrieved_indices.append(retrieved_indices)
 
     # concat
     all_retrieved_indices = np.stack(all_retrieved_indices, axis=0)
-    assert all_retrieved_indices.shape == (len_dataset, B, args.num_to_retrieve, 2)
+    assert all_retrieved_indices.shape == (len(all_row_idxs), B, args.num_to_retrieve, 2)
 
     # save arrays as bin for easy memmap access and faster loading
-    all_retrieved_indices.tofile(f"{dataset_path}/retrieved_indices_{len_dataset}_{B}_{args.num_to_retrieve}_2.bin")
+    all_retrieved_indices.tofile(f"{dataset_path}/retrieved_indices_{len(all_row_idxs)}_{B}_{args.num_to_retrieve}_2.bin")
 
 if __name__ == "__main__":
     main()
\ No newline at end of file
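
Below is a minimal usage sketch of the retrieval helpers introduced in jat_regent/utils.py by this diff (get_task_info, collect_all_data, build_index_vector, retrieve_vector), mirroring the flow of scripts_regent/offline_retrieval_jat_regent.py for a vector-observation task. The task name, top-k value, and core count are illustrative placeholders, and the episode-exclusion step is simplified to excluding only the query row; this is a sketch of the intended call pattern, not part of the commit.

```python
# Sketch: offline retrieval for a vector-observation task with the refactored helpers.
# "mujoco-ant", num_to_retrieve=100, and nb_cores_autofaiss=16 are placeholder values.
import numpy as np
from datasets import load_from_disk
from datasets.config import HF_DATASETS_CACHE
from jat_regent.utils import get_task_info, collect_all_data, build_index_vector, retrieve_vector

task = "mujoco-ant"  # placeholder task
dataset = load_from_disk(f"{HF_DATASETS_CACHE}/jat-project/jat-dataset-tokenized/{task}")

rew_key, attn_key, obs_key, act_key, B, obs_dim = get_task_info(task)
kwargs = {"B": B, "obs_dim": obs_dim, "task": task, "nb_cores_autofaiss": 16}

# collect the first ~100k states and their attention masks
all_rows_of_obs_OG, all_attn_masks_OG, all_row_idxs = collect_all_data(dataset, task, obs_key)

# query with the (masked) observations of row 0
datarow = dataset["train"][0]
attn_mask = np.array(datarow[attn_key]).astype(bool)
row_of_obs = np.array(datarow[obs_key])[attn_mask]

# build a FAISS index over every other row (the real script excludes the query's whole episode)
all_other_row_idxs = [idx for idx in all_row_idxs if idx != 0]
all_indices, knn_index = build_index_vector(all_rows_of_obs_OG, all_attn_masks_OG,
                                            all_other_row_idxs, kwargs)

# retrieve the top-100 nearest states; each entry is a (row_idx, position) pair in the dataset
retrieved_indices = retrieve_vector(row_of_obs, knn_index, all_indices,
                                    num_to_retrieve=100, kwargs=kwargs)
print(retrieved_indices.shape)  # (row_of_obs.shape[0], 100, 2)
```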