diff --git a/README.md b/README.md
index e51a12b..a6e1ca1 100644
--- a/README.md
+++ b/README.md
@@ -21,6 +21,21 @@ conda activate jat
pip install -e .[dev]
```
+## REGENT fork of sample-factory: Installation
+Following [this install link](https://www.samplefactory.dev/01-get-started/installation/), but for the fork:
+```shell
+git clone https://github.com/kaustubhsridhar/sample-factory.git
+cd sample-factory
+pip install -e .[dev,mujoco,atari,envpool,vizdoom]
+```
+
+## REGENT fork of sample-factory: Train Unseen Env Policies and Generate Datasets
+Train policies using envpool's atari:
+```shell
+bash scripts_sample-factory/train_unseen_atari.sh
+```
+Note that the training command inside the above script was obtained from the config files of Ed Beeching's Atari 57 models on Huggingface. An example is [here](https://huggingface.co/edbeeching/atari_2B_atari_mspacman_1111/blob/main/cfg.json#L124). See my discussion [here](https://huggingface.co/edbeeching/atari_2B_atari_mspacman_1111/discussions/2).
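+
+To inspect one of those configs yourself, here is a minimal sketch (assuming `huggingface_hub` is installed; the repo id is the example model linked above):
+```python
+import json
+
+from huggingface_hub import hf_hub_download
+
+# Download cfg.json from the example model repo and peek at the training hyperparameters.
+path = hf_hub_download(repo_id="edbeeching/atari_2B_atari_mspacman_1111", filename="cfg.json")
+with open(path) as f:
+    cfg = json.load(f)
+print(json.dumps(cfg, indent=2)[:500])
+```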
+
## PREV Installation
To get started with JAT, follow these steps:
@@ -155,12 +170,21 @@ python -u scripts_jat_regent/eval_RandP.py --task ${TASK} &> outputs/RandP/${TAS
```
### REGENT Analyze data
+Necessary:
```shell
-python -u examples_regent/compare_datasets.py &> examples_regent/compare_datasets.txt &
-
python -u examples_regent/analyze_rows_tokenized.py &> examples_regent/analyze_rows_tokenized.txt &
+```
+Already run; the output dicts are hardcoded in the code:
+```shell
python -u examples_regent/get_dim_all_vector_tasks.py &> examples_regent/get_dim_all_vector_tasks.txt &
+
+python -u examples_regent/count_rows_to_consider.py &> examples_regent/count_rows_to_consider.txt &
+```
+
+Optional:
+```shell
+python -u examples_regent/compare_datasets.py &> examples_regent/compare_datasets.txt &
```
## PREV Dataset
diff --git a/jat_regent/RandP.py b/jat_regent/RandP.py
deleted file mode 100644
index b2bd8bf..0000000
--- a/jat_regent/RandP.py
+++ /dev/null
@@ -1,38 +0,0 @@
-import warnings
-from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
-
-import numpy as np
-import torch
-import torch.nn.functional as F
-from gymnasium import spaces
-from torch import BoolTensor, FloatTensor, LongTensor, Tensor, nn
-from transformers import GPTNeoModel, GPTNeoPreTrainedModel
-from transformers.modeling_outputs import ModelOutput
-from transformers.models.vit.modeling_vit import ViTPatchEmbeddings
-
-from jat.configuration_jat import JatConfig
-from jat.processing_jat import JatProcessor
-
-
-class RandP():
- def __init__(self, dataset) -> None:
- self.steps = 0
- # create an index for retrieval in vector obs envs (OR) collect all images in Atari
-
- def reset_rl(self):
- self.steps = 0
-
- def get_next_action(
- self,
- processor: JatProcessor,
- continuous_observation: Optional[List[float]] = None,
- discrete_observation: Optional[List[int]] = None,
- text_observation: Optional[str] = None,
- image_observation: Optional[np.ndarray] = None,
- action_space: Union[spaces.Box, spaces.Discrete] = None,
- reward: Optional[float] = None,
- deterministic: bool = False,
- context_window: Optional[int] = None,
- ):
- pass
\ No newline at end of file
diff --git a/jat_regent/modelling_jat_regent.py b/jat_regent/modelling_jat_regent.py
deleted file mode 100644
index e69de29..0000000
diff --git a/jat_regent/utils.py b/jat_regent/utils.py
index 56bfb44..36f6cca 100644
--- a/jat_regent/utils.py
+++ b/jat_regent/utils.py
@@ -8,23 +8,35 @@ from tqdm import tqdm
from autofaiss import build_index
+UNSEEN_TASK_NAMES = { # Total -- atari: 57, metaworld: 50, babyai: 39, mujoco: 11
+
+}
+
def myprint(str):
- # check if first character of string is a newline character
- if str[0] == '\n':
- str_without_newline = str[1:]
+ # count the string's leading newline characters, printing a blank line for each
+ num_newlines = 0
+ while num_newlines < len(str) and str[num_newlines] == '\n':
print()
- else:
- str_without_newline = str
+ num_newlines += 1
+ str_without_newline = str[num_newlines:]
print(f'{datetime.now().strftime("%Y-%m-%d %H:%M:%S")}: {str_without_newline}')
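+# e.g. myprint('\n\nhello') prints two blank lines, then "<YYYY-MM-DD HH:MM:SS>: hello"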
def is_png_img(item):
return isinstance(item, PngImagePlugin.PngImageFile)
+def get_last_row_for_1M_states(task):
+ last_row_idx = {'atari-alien': 14134, 'atari-amidar': 14319, 'atari-assault': 14427, 'atari-asterix': 14456, 'atari-asteroids': 14348, 'atari-atlantis': 14325, 'atari-bankheist': 14167, 'atari-battlezone': 13981, 'atari-beamrider': 13442, 'atari-berzerk': 13534, 'atari-bowling': 14110, 'atari-boxing': 14542, 'atari-breakout': 13474, 'atari-centipede': 14196, 'atari-choppercommand': 13397, 'atari-crazyclimber': 14026, 'atari-defender': 13504, 'atari-demonattack': 13499, 'atari-doubledunk': 14292, 'atari-enduro': 13260, 'atari-fishingderby': 14073, 'atari-freeway': 14016, 'atari-frostbite': 14075, 'atari-gopher': 13143, 'atari-gravitar': 14405, 'atari-hero': 14044, 'atari-icehockey': 14017, 'atari-jamesbond': 12678, 'atari-kangaroo': 14248, 'atari-krull': 14204, 'atari-kungfumaster': 14030, 'atari-montezumarevenge': 14219, 'atari-mspacman': 14120, 'atari-namethisgame': 13575, 'atari-phoenix': 13539, 'atari-pitfall': 14287, 'atari-pong': 14151, 'atari-privateeye': 14105, 'atari-qbert': 14026, 'atari-riverraid': 14275, 'atari-roadrunner': 14127, 'atari-robotank': 14079, 'atari-seaquest': 14097, 'atari-skiing': 14708, 'atari-solaris': 14199, 'atari-spaceinvaders': 12652, 'atari-stargunner': 13822, 'atari-surround': 13840, 'atari-tennis': 14062, 'atari-timepilot': 13896, 'atari-tutankham': 13121, 'atari-upndown': 13504, 'atari-venture': 14260, 'atari-videopinball': 14272, 'atari-wizardofwor': 13920, 'atari-yarsrevenge': 13981, 'atari-zaxxon': 13833, 'babyai-action-obj-door': 95000, 'babyai-blocked-unlock-pickup': 29279, 'babyai-boss-level-no-unlock': 12087, 'babyai-boss-level': 12101, 'babyai-find-obj-s5': 32974, 'babyai-go-to-door': 95000, 'babyai-go-to-imp-unlock': 9286, 'babyai-go-to-local': 95000, 'babyai-go-to-obj-door': 95000, 'babyai-go-to-obj': 95000, 'babyai-go-to-red-ball-grey': 95000, 'babyai-go-to-red-ball-no-dists': 95000, 'babyai-go-to-red-ball': 95000, 'babyai-go-to-red-blue-ball': 95000, 'babyai-go-to-seq': 13744, 'babyai-go-to': 18974, 'babyai-key-corridor': 9014, 'babyai-mini-boss-level': 38119, 'babyai-move-two-across-s8n9': 24505, 'babyai-one-room-s8': 95000, 'babyai-open-door': 95000, 'babyai-open-doors-order-n4': 95000, 'babyai-open-red-door': 95000, 'babyai-open-two-doors': 73291, 'babyai-open': 32559, 'babyai-pickup-above': 34084, 'babyai-pickup-dist': 89640, 'babyai-pickup-loc': 95000, 'babyai-pickup': 18670, 'babyai-put-next-local': 83187, 'babyai-put-next': 56986, 'babyai-synth-loc': 21605, 'babyai-synth-seq': 13049, 'babyai-synth': 19409, 'babyai-unblock-pickup': 17881, 'babyai-unlock-local': 71186, 'babyai-unlock-pickup': 50883, 'babyai-unlock-to-unlock': 23062, 'babyai-unlock': 11734, 'metaworld-assembly': 10000, 'metaworld-basketball': 10000, 'metaworld-bin-picking': 10000, 'metaworld-box-close': 10000, 'metaworld-button-press-topdown-wall': 10000, 'metaworld-button-press-topdown': 10000, 'metaworld-button-press-wall': 10000, 'metaworld-button-press': 10000, 'metaworld-coffee-button': 10000, 'metaworld-coffee-pull': 10000, 'metaworld-coffee-push': 10000, 'metaworld-dial-turn': 10000, 'metaworld-disassemble': 10000, 'metaworld-door-close': 10000, 'metaworld-door-lock': 10000, 'metaworld-door-open': 10000, 'metaworld-door-unlock': 10000, 'metaworld-drawer-close': 10000, 'metaworld-drawer-open': 10000, 'metaworld-faucet-close': 10000, 'metaworld-faucet-open': 10000, 'metaworld-hammer': 10000, 'metaworld-hand-insert': 10000, 'metaworld-handle-press-side': 10000, 'metaworld-handle-press': 10000, 'metaworld-handle-pull-side': 10000, 'metaworld-handle-pull': 10000, 
'metaworld-lever-pull': 10000, 'metaworld-peg-insert-side': 10000, 'metaworld-peg-unplug-side': 10000, 'metaworld-pick-out-of-hole': 10000, 'metaworld-pick-place-wall': 10000, 'metaworld-pick-place': 10000, 'metaworld-plate-slide-back-side': 10000, 'metaworld-plate-slide-back': 10000, 'metaworld-plate-slide-side': 10000, 'metaworld-plate-slide': 10000, 'metaworld-push-back': 10000, 'metaworld-push-wall': 10000, 'metaworld-push': 10000, 'metaworld-reach-wall': 10000, 'metaworld-reach': 10000, 'metaworld-shelf-place': 10000, 'metaworld-soccer': 10000, 'metaworld-stick-pull': 10000, 'metaworld-stick-push': 10000, 'metaworld-sweep-into': 10000, 'metaworld-sweep': 10000, 'metaworld-window-close': 10000, 'metaworld-window-open': 10000, 'mujoco-ant': 4023, 'mujoco-doublependulum': 4002, 'mujoco-halfcheetah': 4000, 'mujoco-hopper': 4931, 'mujoco-humanoid': 4119, 'mujoco-pendulum': 4959, 'mujoco-pusher': 9000, 'mujoco-reacher': 9000, 'mujoco-standup': 4000, 'mujoco-swimmer': 4000, 'mujoco-walker': 4101}
+ return last_row_idx[task]
+
+def get_last_row_for_100k_states(task):
+ last_row_idx = {'atari-alien': 3135, 'atari-amidar': 3142, 'atari-assault': 3132, 'atari-asterix': 3181, 'atari-asteroids': 3127, 'atari-atlantis': 3128, 'atari-bankheist': 3156, 'atari-battlezone': 3136, 'atari-beamrider': 3131, 'atari-berzerk': 3127, 'atari-bowling': 3148, 'atari-boxing': 3227, 'atari-breakout': 3128, 'atari-centipede': 3176, 'atari-choppercommand': 3144, 'atari-crazyclimber': 3134, 'atari-defender': 3127, 'atari-demonattack': 3127, 'atari-doubledunk': 3175, 'atari-enduro': 3126, 'atari-fishingderby': 3155, 'atari-freeway': 3131, 'atari-frostbite': 3146, 'atari-gopher': 3128, 'atari-gravitar': 3202, 'atari-hero': 3144, 'atari-icehockey': 3138, 'atari-jamesbond': 3131, 'atari-kangaroo': 3160, 'atari-krull': 3162, 'atari-kungfumaster': 3143, 'atari-montezumarevenge': 3168, 'atari-mspacman': 3143, 'atari-namethisgame': 3131, 'atari-phoenix': 3127, 'atari-pitfall': 3131, 'atari-pong': 3160, 'atari-privateeye': 3158, 'atari-qbert': 3136, 'atari-riverraid': 3157, 'atari-roadrunner': 3150, 'atari-robotank': 3133, 'atari-seaquest': 3138, 'atari-skiing': 3271, 'atari-solaris': 3129, 'atari-spaceinvaders': 3128, 'atari-stargunner': 3129, 'atari-surround': 3143, 'atari-tennis': 3129, 'atari-timepilot': 3132, 'atari-tutankham': 3127, 'atari-upndown': 3127, 'atari-venture': 3148, 'atari-videopinball': 3130, 'atari-wizardofwor': 3138, 'atari-yarsrevenge': 3129, 'atari-zaxxon': 3133, 'babyai-action-obj-door': 15923, 'babyai-blocked-unlock-pickup': 2919, 'babyai-boss-level-no-unlock': 1217, 'babyai-boss-level': 1159, 'babyai-find-obj-s5': 3345, 'babyai-go-to-door': 18875, 'babyai-go-to-imp-unlock': 923, 'babyai-go-to-local': 18724, 'babyai-go-to-obj-door': 16472, 'babyai-go-to-obj': 20197, 'babyai-go-to-red-ball-grey': 16953, 'babyai-go-to-red-ball-no-dists': 20165, 'babyai-go-to-red-ball': 18730, 'babyai-go-to-red-blue-ball': 16934, 'babyai-go-to-seq': 1439, 'babyai-go-to': 1964, 'babyai-key-corridor': 900, 'babyai-mini-boss-level': 3789, 'babyai-move-two-across-s8n9': 2462, 'babyai-one-room-s8': 16994, 'babyai-open-door': 13565, 'babyai-open-doors-order-n4': 9706, 'babyai-open-red-door': 21185, 'babyai-open-two-doors': 7348, 'babyai-open': 3331, 'babyai-pickup-above': 3392, 'babyai-pickup-dist': 19693, 'babyai-pickup-loc': 16405, 'babyai-pickup': 1806, 'babyai-put-next-local': 8303, 'babyai-put-next': 5703, 'babyai-synth-loc': 2183, 'babyai-synth-seq': 1316, 'babyai-synth': 1964, 'babyai-unblock-pickup': 1886, 'babyai-unlock-local': 7118, 'babyai-unlock-pickup': 5107, 'babyai-unlock-to-unlock': 2309, 'babyai-unlock': 1177, 'metaworld-assembly': 1000, 'metaworld-basketball': 1000, 'metaworld-bin-picking': 1000, 'metaworld-box-close': 1000, 'metaworld-button-press-topdown-wall': 1000, 'metaworld-button-press-topdown': 1000, 'metaworld-button-press-wall': 1000, 'metaworld-button-press': 1000, 'metaworld-coffee-button': 1000, 'metaworld-coffee-pull': 1000, 'metaworld-coffee-push': 1000, 'metaworld-dial-turn': 1000, 'metaworld-disassemble': 1000, 'metaworld-door-close': 1000, 'metaworld-door-lock': 1000, 'metaworld-door-open': 1000, 'metaworld-door-unlock': 1000, 'metaworld-drawer-close': 1000, 'metaworld-drawer-open': 1000, 'metaworld-faucet-close': 1000, 'metaworld-faucet-open': 1000, 'metaworld-hammer': 1000, 'metaworld-hand-insert': 1000, 'metaworld-handle-press-side': 1000, 'metaworld-handle-press': 1000, 'metaworld-handle-pull-side': 1000, 'metaworld-handle-pull': 1000, 'metaworld-lever-pull': 1000, 'metaworld-peg-insert-side': 1000, 'metaworld-peg-unplug-side': 1000, 
'metaworld-pick-out-of-hole': 1000, 'metaworld-pick-place-wall': 1000, 'metaworld-pick-place': 1000, 'metaworld-plate-slide-back-side': 1000, 'metaworld-plate-slide-back': 1000, 'metaworld-plate-slide-side': 1000, 'metaworld-plate-slide': 1000, 'metaworld-push-back': 1000, 'metaworld-push-wall': 1000, 'metaworld-push': 1000, 'metaworld-reach-wall': 1000, 'metaworld-reach': 1000, 'metaworld-shelf-place': 1000, 'metaworld-soccer': 1000, 'metaworld-stick-pull': 1000, 'metaworld-stick-push': 1000, 'metaworld-sweep-into': 1000, 'metaworld-sweep': 1000, 'metaworld-window-close': 1000, 'metaworld-window-open': 1000, 'mujoco-ant': 401, 'mujoco-doublependulum': 401, 'mujoco-halfcheetah': 400, 'mujoco-hopper': 491, 'mujoco-humanoid': 415, 'mujoco-pendulum': 495, 'mujoco-pusher': 1000, 'mujoco-reacher': 2000, 'mujoco-standup': 400, 'mujoco-swimmer': 400, 'mujoco-walker': 407}
+ return last_row_idx[task]
+
def get_obs_dim(task):
assert task.startswith("babyai") or task.startswith("metaworld") or task.startswith("mujoco")
all_obs_dims={'babyai-action-obj-door': 212, 'babyai-blocked-unlock-pickup': 212, 'babyai-boss-level-no-unlock': 212, 'babyai-boss-level': 212, 'babyai-find-obj-s5': 212, 'babyai-go-to-door': 212, 'babyai-go-to-imp-unlock': 212, 'babyai-go-to-local': 212, 'babyai-go-to-obj-door': 212, 'babyai-go-to-obj': 212, 'babyai-go-to-red-ball-grey': 212, 'babyai-go-to-red-ball-no-dists': 212, 'babyai-go-to-red-ball': 212, 'babyai-go-to-red-blue-ball': 212, 'babyai-go-to-seq': 212, 'babyai-go-to': 212, 'babyai-key-corridor': 212, 'babyai-mini-boss-level': 212, 'babyai-move-two-across-s8n9': 212, 'babyai-one-room-s8': 212, 'babyai-open-door': 212, 'babyai-open-doors-order-n4': 212, 'babyai-open-red-door': 212, 'babyai-open-two-doors': 212, 'babyai-open': 212, 'babyai-pickup-above': 212, 'babyai-pickup-dist': 212, 'babyai-pickup-loc': 212, 'babyai-pickup': 212, 'babyai-put-next-local': 212, 'babyai-put-next': 212, 'babyai-synth-loc': 212, 'babyai-synth-seq': 212, 'babyai-synth': 212, 'babyai-unblock-pickup': 212, 'babyai-unlock-local': 212, 'babyai-unlock-pickup': 212, 'babyai-unlock-to-unlock': 212, 'babyai-unlock': 212, 'metaworld-assembly': 39, 'metaworld-basketball': 39, 'metaworld-bin-picking': 39, 'metaworld-box-close': 39, 'metaworld-button-press-topdown-wall': 39, 'metaworld-button-press-topdown': 39, 'metaworld-button-press-wall': 39, 'metaworld-button-press': 39, 'metaworld-coffee-button': 39, 'metaworld-coffee-pull': 39, 'metaworld-coffee-push': 39, 'metaworld-dial-turn': 39, 'metaworld-disassemble': 39, 'metaworld-door-close': 39, 'metaworld-door-lock': 39, 'metaworld-door-open': 39, 'metaworld-door-unlock': 39, 'metaworld-drawer-close': 39, 'metaworld-drawer-open': 39, 'metaworld-faucet-close': 39, 'metaworld-faucet-open': 39, 'metaworld-hammer': 39, 'metaworld-hand-insert': 39, 'metaworld-handle-press-side': 39, 'metaworld-handle-press': 39, 'metaworld-handle-pull-side': 39, 'metaworld-handle-pull': 39, 'metaworld-lever-pull': 39, 'metaworld-peg-insert-side': 39, 'metaworld-peg-unplug-side': 39, 'metaworld-pick-out-of-hole': 39, 'metaworld-pick-place-wall': 39, 'metaworld-pick-place': 39, 'metaworld-plate-slide-back-side': 39, 'metaworld-plate-slide-back': 39, 'metaworld-plate-slide-side': 39, 'metaworld-plate-slide': 39, 'metaworld-push-back': 39, 'metaworld-push-wall': 39, 'metaworld-push': 39, 'metaworld-reach-wall': 39, 'metaworld-reach': 39, 'metaworld-shelf-place': 39, 'metaworld-soccer': 39, 'metaworld-stick-pull': 39, 'metaworld-stick-push': 39, 'metaworld-sweep-into': 39, 'metaworld-sweep': 39, 'metaworld-window-close': 39, 'metaworld-window-open': 39, 'mujoco-ant': 27, 'mujoco-doublependulum': 11, 'mujoco-halfcheetah': 17, 'mujoco-hopper': 11, 'mujoco-humanoid': 376, 'mujoco-pendulum': 4, 'mujoco-pusher': 23, 'mujoco-reacher': 11, 'mujoco-standup': 376, 'mujoco-swimmer': 8, 'mujoco-walker': 17}
- return all_obs_dims[task]
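+ # return a 1-tuple so vector obs dims unpack like the multi-dim Atari obs_dim (both used as *obs_dim downstream)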
+ return (all_obs_dims[task],)
def get_act_dim(task):
assert task.startswith("babyai") or task.startswith("metaworld") or task.startswith("mujoco")
@@ -36,141 +48,188 @@ def get_act_dim(task):
elif task.startswith("mujoco"):
all_act_dims={'mujoco-ant': 8, 'mujoco-doublependulum': 1, 'mujoco-halfcheetah': 6, 'mujoco-hopper': 3, 'mujoco-humanoid': 17, 'mujoco-pendulum': 1, 'mujoco-pusher': 7, 'mujoco-reacher': 2, 'mujoco-standup': 17, 'mujoco-swimmer': 2, 'mujoco-walker': 6}
return all_act_dims[task]
-
-def process_row_atari(attn_mask, row_of_obs, task):
- """
- Example for selection with bools:
- >>> a = np.array([0,1,2,3,4,5])
- >>> b = np.array([1,0,0,0,0,1]).astype(bool)
- >>> a[b]
- array([0, 5])
- """
- attn_mask = np.array(attn_mask).astype(bool)
- row_of_obs = torch.stack([to_tensor(np.array(img)) for img in row_of_obs])
- row_of_obs = row_of_obs[attn_mask]
+def get_task_info(task):
+ rew_key = 'rewards'
+ attn_key = 'attention_mask'
+ if task.startswith("atari"):
+ obs_key = 'image_observations'
+ act_key = 'discrete_actions'
+ B = 32 # half of 54
+ obs_dim = (3, 4*84, 84)
+ elif task.startswith("babyai"):
+ obs_key = 'discrete_observations' # the raw dataset also has 'text_observations'; the tokenized dataset folds them into the discrete observations
+ act_key = 'discrete_actions'
+ B = 256 # half of 512
+ obs_dim = get_obs_dim(task)
+ elif task.startswith("metaworld") or task.startswith("mujoco"):
+ obs_key = 'continuous_observations'
+ act_key = 'continuous_actions'
+ B = 256
+ obs_dim = get_obs_dim(task)
+
+ return rew_key, attn_key, obs_key, act_key, B, obs_dim
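+ # e.g. get_task_info('mujoco-ant') -> ('rewards', 'attention_mask', 'continuous_observations', 'continuous_actions', 256, (27,))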
+
+def process_row_of_obs_atari_full_without_mask(row_of_obs):
+
+ if not isinstance(row_of_obs, torch.Tensor):
+ row_of_obs = torch.stack([to_tensor(np.array(img)) for img in row_of_obs])
row_of_obs = row_of_obs * 0.5 + 0.5 # denormalize from [-1, 1] to [0, 1]
- assert row_of_obs.shape == (sum(attn_mask), 84, 4, 84)
+ assert row_of_obs.shape == (len(row_of_obs), 84, 4, 84)
row_of_obs = row_of_obs.permute(0, 2, 1, 3) # (*, 4, 84, 84)
- row_of_obs = row_of_obs.reshape(sum(attn_mask), 4*84, 84) # put side-by-side
+ row_of_obs = row_of_obs.reshape(len(row_of_obs), 4*84, 84) # put side-by-side
row_of_obs = row_of_obs.unsqueeze(1).repeat(1, 3, 1, 1) # repeat for 3 channels
- assert row_of_obs.shape == (sum(attn_mask), 3, 4*84, 84) # sum(attn_mask) is the batch size dimension
-
- return attn_mask, row_of_obs
+ assert row_of_obs.shape == (len(row_of_obs), 3, 4*84, 84) # len(row_of_obs) is the batch size dimension
+
+ return row_of_obs
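+# net effect: N frame-stacked Atari observations -> an (N, 3, 4*84, 84) float tensor in [0, 1], frames laid side by side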
-def process_row_vector(attn_mask, row_of_obs, task, return_numpy=False):
- attn_mask = np.array(attn_mask).astype(bool)
+def collect_all_atari_data(dataset, all_row_idxs=None):
+ if all_row_idxs is None:
+ all_row_idxs = list(range(len(dataset['train'])))
- row_of_obs = np.array(row_of_obs)
- if not return_numpy:
- row_of_obs = torch.tensor(row_of_obs)
- row_of_obs = row_of_obs[attn_mask]
- assert row_of_obs.shape == (sum(attn_mask), get_obs_dim(task))
-
- return attn_mask, row_of_obs
-
-def retrieve_atari(row_of_obs, # query: (row_B, 3, 4*84, 84)
- dataset, # to retrieve from
- all_rows_to_consider, # rows to consider
- num_to_retrieve, # top-k
+ all_rows_of_obs = []
+ all_attn_masks = []
+ for row_idx in tqdm(all_row_idxs):
+ datarow = dataset['train'][row_idx]
+ row_of_obs = process_row_of_obs_atari_full_without_mask(datarow['image_observations'])
+ attn_mask = np.array(datarow['attention_mask']).astype(bool)
+ all_rows_of_obs.append(row_of_obs) # appending tensor
+ all_attn_masks.append(attn_mask) # appending np array
+ all_rows_of_obs = torch.stack(all_rows_of_obs, dim=0) # stacking tensors
+ all_attn_masks = np.stack(all_attn_masks, axis=0) # stacking np arrays
+ assert (all_rows_of_obs.shape == (len(all_row_idxs), 32, 3, 4*84, 84) and
+ all_attn_masks.shape == (len(all_row_idxs), 32))
+ return all_attn_masks, all_rows_of_obs
+
+def collect_all_data(dataset, task, obs_key):
+ last_row_idx = get_last_row_for_100k_states(task)
+ all_row_idxs = list(range(last_row_idx))
+ if task.startswith("atari"):
+ myprint("Collecting all Atari images and Atari attention masks...")
+ all_attn_masks_OG, all_rows_of_obs_OG = collect_all_atari_data(dataset, all_row_idxs)
+ else:
+ datarows = dataset['train'][all_row_idxs]
+ all_rows_of_obs_OG = np.array(datarows[obs_key])
+ all_attn_masks_OG = np.array(datarows['attention_mask']).astype(bool)
+ return all_rows_of_obs_OG, all_attn_masks_OG, all_row_idxs
+
+def collect_subset(all_rows_of_obs_OG,
+ all_attn_masks_OG,
+ all_rows_to_consider,
+ kwargs
+ ):
+ """
+ Collects the subset of data selected by all_rows_to_consider, reshapes it, and builds the matching all_indices array.
+ Used in both retrieve_atari() and retrieve_vector() --> build_index_vector().
+ """
+ myprint(f'\n\n\n' + ('-'*100) + f'Collecting subset...')
+ # read kwargs
+ B, task, obs_dim = kwargs['B'], kwargs['task'], kwargs['obs_dim']
+
+ # take subset based on all_rows_to_consider
+ myprint(f'Taking subset of data based on all_rows_to_consider...')
+ all_processed_rows_of_obs = all_rows_of_obs_OG[all_rows_to_consider]
+ all_attn_masks = all_attn_masks_OG[all_rows_to_consider]
+ assert (all_processed_rows_of_obs.shape == (len(all_rows_to_consider), B, *obs_dim) and
+ all_attn_masks.shape == (len(all_rows_to_consider), B))
+
+ # reshape
+ myprint(f'Reshaping data...')
+ all_attn_masks = all_attn_masks.reshape(-1)
+ all_processed_rows_of_obs = all_processed_rows_of_obs.reshape(-1, *obs_dim)
+ all_processed_rows_of_obs = all_processed_rows_of_obs[all_attn_masks]
+ assert (all_attn_masks.shape == (len(all_rows_to_consider) * B,) and
+ all_processed_rows_of_obs.shape == (np.sum(all_attn_masks), *obs_dim))
+
+ # collect indices of data
+ myprint(f'Collecting indices of data...')
+ all_indices = np.array([[row_idx, i] for row_idx in all_rows_to_consider for i in range(B)])
+ all_indices = all_indices[all_attn_masks] # this is fine because all attn masks have 0s that only come after 1s
+ assert all_indices.shape == (np.sum(all_attn_masks), 2)
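+ # e.g. with B=2 and all_rows_to_consider=[5, 7], where row 7's attn mask is [1, 0]:
+ # all_indices is [[5,0],[5,1],[7,0],[7,1]] before masking and [[5,0],[5,1],[7,0]] after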
+
+ myprint(f'{all_indices.shape=}, {all_processed_rows_of_obs.shape=}')
+ myprint(('-'*100) + '\n\n\n')
+ return all_indices, all_processed_rows_of_obs
+
+def retrieve_atari(row_of_obs, # query: (xbdim, 3, 4*84, 84), i.e. (xbdim, *obs_dim)
+ all_processed_rows_of_obs,
+ all_indices,
+ num_to_retrieve,
kwargs
- ):
+ ):
+ """
+ Retrieval for Atari: image observations compared via SSIM distance on the GPU.
+ """
assert isinstance(row_of_obs, torch.Tensor)
# read kwargs # Note: B = len of row
- B, attn_key, obs_key, device, task, batch_size_retrieval = kwargs['B'], kwargs['attn_key'], kwargs['obs_key'], kwargs['device'], kwargs['task'], kwargs['batch_size_retrieval']
+ B, device, batch_size_retrieval = kwargs['B'], kwargs['device'], kwargs['batch_size_retrieval']
# batch size of row_of_obs which can be <= B since we process before calling this function
- row_B = row_of_obs.shape[0]
-
+ xbdim = row_of_obs.shape[0]
+
+ # number of candidate observations we can retrieve from
+ ydim = all_processed_rows_of_obs.shape[0]
+
# first argument for ssim
- repeated_row_og = row_of_obs.repeat_interleave(B, dim=0).to(device)
- assert repeated_row_og.shape == (row_B*B, 3, 4*84, 84)
+ xbatch = row_of_obs.repeat_interleave(batch_size_retrieval, dim=0).to(device)
+ assert xbatch.shape == (xbdim * batch_size_retrieval, 3, 4*84, 84)
- # iterate over all other rows
+ # iterate over data that we can retrieve from in batches
all_ssim = []
- all_indices = []
- total = 0
- for other_row_idx in tqdm(all_rows_to_consider):
- other_attn_mask, other_row_of_obs = process_row_atari(dataset['train'][other_row_idx][attn_key], dataset['train'][other_row_idx][obs_key])
-
- # batch size of other_row_of_obs
- other_row_B = other_row_of_obs.shape[0]
- total += other_row_B
-
- # first argument for ssim: RECHECK
- if other_row_B < B: # when other row has less observations than expected
- repeated_row = row_of_obs.repeat_interleave(other_row_B, dim=0).to(device)
- elif other_row_B == B: # otherwise just use the one created before the for loop
- repeated_row = repeated_row_og
- assert repeated_row.shape == (row_B*other_row_B, 3, 4*84, 84)
-
+ for j in range(0, ydim, batch_size_retrieval):
# second argument for ssim
- repeated_other_row = other_row_of_obs.repeat(row_B, 1, 1, 1).to(device)
- assert repeated_other_row.shape == (row_B*other_row_B, 3, 4*84, 84)
+ ybatch = all_processed_rows_of_obs[j:j+batch_size_retrieval]
+ ybdim = ybatch.shape[0]
+ ybatch = ybatch.repeat(xbdim, 1, 1, 1).to(device)
+ assert ybatch.shape == (ybdim * xbdim, 3, 4*84, 84)
+
+ if ybdim < batch_size_retrieval: # for last batch
+ xbatch = row_of_obs.repeat_interleave(ybdim, dim=0).to(device)
+ assert xbatch.shape == (xbdim * ybdim, 3, 4*84, 84)
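+ # pairing: flat element k compares row_of_obs[k // ybdim] against candidate j + (k % ybdim),
+ # so the per-pair SSIM scores below reshape into an (xbdim, ybdim) query-vs-candidate matrix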
# compare via ssim and updated all_ssim
- ssim_score = ssim(repeated_row, repeated_other_row, data_range=1.0, size_average=False)
- ssim_score = ssim_score.reshape(row_B, other_row_B)
+ ssim_score = ssim(xbatch, ybatch, data_range=1.0, size_average=False)
+ ssim_score = ssim_score.reshape(xbdim, ybdim)
all_ssim.append(ssim_score)
- # update all_indices
- all_indices.extend([[other_row_idx, i] for i in range(other_row_B)])
-
# concat
all_ssim = torch.cat(all_ssim, dim=1)
- assert all_ssim.shape == (row_B, total)
+ assert all_ssim.shape == (xbdim, ydim)
- all_indices = np.array(all_indices)
- assert all_indices.shape == (total, 2)
+ assert all_indices.shape == (ydim, 2)
# get top-k indices
topk_values, topk_indices = torch.topk(all_ssim, num_to_retrieve, dim=1, largest=True)
topk_indices = topk_indices.cpu().numpy()
- assert topk_indices.shape == (row_B, num_to_retrieve)
+ assert topk_indices.shape == (xbdim, num_to_retrieve)
# convert topk indices to indices in the dataset
- retrieved_indices = np.array(all_indices[topk_indices])
- assert retrieved_indices.shape == (row_B, num_to_retrieve, 2)
-
- # pad the above to expected B
- if row_B < B:
- retrieved_indices = np.concatenate([retrieved_indices, np.zeros((B-row_B, num_to_retrieve, 2), dtype=int)], axis=0)
- assert retrieved_indices.shape == (B, num_to_retrieve, 2)
+ retrieved_indices = all_indices[topk_indices]
+ assert retrieved_indices.shape == (xbdim, num_to_retrieve, 2)
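+ # retrieved_indices[q, k] = [row_idx, position] of the k-th most SSIM-similar stored observation to query q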
return retrieved_indices
-def build_index_vector(all_rows_of_obs_og,
- all_attn_masks_og,
+def build_index_vector(all_rows_of_obs_OG,
+ all_attn_masks_OG,
all_rows_to_consider,
kwargs
- ):
+ ):
+ """
+ Builds FAISS index for vector observation environments.
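+ Returns (all_indices, knn_index); ids from knn_index.search() index into the all_indices [row_idx, position] table.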
+ """
# read kwargs # Note: B = len of row
- B, attn_key, obs_key, device, task, batch_size_retrieval, nb_cores_autofaiss = kwargs['B'], kwargs['attn_key'], kwargs['obs_key'], kwargs['device'], kwargs['task'], kwargs['batch_size_retrieval'], kwargs['nb_cores_autofaiss']
- obs_dim = get_obs_dim(task)
+ nb_cores_autofaiss = kwargs['nb_cores_autofaiss']
- # take subset based on all_rows_to_consider
- myprint(f'Taking subset')
- all_rows_of_obs = all_rows_of_obs_og[all_rows_to_consider]
- all_attn_masks = all_attn_masks_og[all_rows_to_consider]
- assert (all_rows_of_obs.shape == (len(all_rows_to_consider), B, obs_dim) and
- all_attn_masks.shape == (len(all_rows_to_consider), B))
-
- # reshape
- all_attn_masks = all_attn_masks.reshape(-1)
- all_rows_of_obs = all_rows_of_obs.reshape(-1, obs_dim)
- all_rows_of_obs = all_rows_of_obs[all_attn_masks]
- assert all_rows_of_obs.shape == (np.sum(all_attn_masks), obs_dim)
+ # take subset based on all_rows_to_consider, reshape, and save indices of data
+ all_indices, all_processed_rows_of_obs = collect_subset(all_rows_of_obs_OG, all_attn_masks_OG, all_rows_to_consider, kwargs)
- # save indices of data to retrieve from
- myprint(f'Saving indices of data to retrieve from')
- all_indices = np.array([[row_idx, i] for row_idx in all_rows_to_consider for i in range(B)])
- all_indices = all_indices[all_attn_masks] # this is fine because all attn masks have 0s that only come after 1s
- assert all_indices.shape == (np.sum(all_attn_masks), 2)
+ # make sure the build_index input is float, otherwise autofaiss fails with an error while reading its temp files
+ all_processed_rows_of_obs = all_processed_rows_of_obs.astype(float)
# build index
- myprint(f'Building index...')
- knn_index, knn_index_infos = build_index(embeddings=all_rows_of_obs, # Note: embeddings have to be float to avoid errors in autofaiss / embedding_reader!
+ myprint(('-'*100) + 'Building index...')
+ knn_index, knn_index_infos = build_index(embeddings=all_processed_rows_of_obs, # Note: embeddings have to be float to avoid errors in autofaiss / embedding_reader!
save_on_disk=False,
min_nearest_neighbors_to_retrieve=20, # default: 20
max_index_query_time_ms=10, # default: 10
@@ -179,34 +238,32 @@ def build_index_vector(all_rows_of_obs_og,
metric_type='l2',
nb_cores=nb_cores_autofaiss, # default: None # "The number of cores to use, by default will use all cores" as seen in https://criteo.github.io/autofaiss/getting_started/quantization.html#the-build-index-command
)
+ myprint(('-'*100) + '\n\n\n')
- return knn_index, all_indices
+ return all_indices, knn_index
-def retrieve_vector(row_of_obs, # query: (row_B, dim)
- dataset, # to retrieve from
- all_rows_to_consider, # rows to consider
- num_to_retrieve, # top-k
+def retrieve_vector(row_of_obs, # query: (xbdim, *obs_dim)
+ knn_index,
+ all_indices,
+ num_to_retrieve,
kwargs
- ):
+ ):
+ """
+ Retrieval for vector observation environments.
+ """
assert isinstance(row_of_obs, np.ndarray)
# read few kwargs
B = kwargs['B']
# batch size of row_of_obs which can be <= B since we process before calling this function
- row_B = row_of_obs.shape[0]
+ xbdim = row_of_obs.shape[0]
- # read dataset_tuple
- all_rows_of_obs, all_attn_masks = dataset
-
- # create index and all_indices
- knn_index, all_indices = build_index_vector(all_rows_of_obs, all_attn_masks, all_rows_to_consider, kwargs)
-
# retrieve
myprint(f'Retrieving...')
_, topk_indices = knn_index.search(row_of_obs, 10 * num_to_retrieve) # faiss search returns (distances, indices)
topk_indices = topk_indices.astype(int)
- assert topk_indices.shape == (row_B, 10 * num_to_retrieve)
+ assert topk_indices.shape == (xbdim, 10 * num_to_retrieve)
# remove -1s and crop to num_to_retrieve
try:
@@ -219,16 +276,10 @@ def retrieve_vector(row_of_obs, # query: (row_B, dim)
print(f'-------------------------------------------------------------------------------------------------------------------------------------------')
print(f'Leaving some -1s in topk_indices and continuing')
topk_indices = np.array([indices[:num_to_retrieve] for indices in topk_indices])
- assert topk_indices.shape == (row_B, num_to_retrieve)
+ assert topk_indices.shape == (xbdim, num_to_retrieve)
# convert topk indices to indices in the dataset
retrieved_indices = all_indices[topk_indices]
- assert retrieved_indices.shape == (row_B, num_to_retrieve, 2)
-
- # pad the above to expected B
- if row_B < B:
- retrieved_indices = np.concatenate([retrieved_indices, np.zeros((B-row_B, num_to_retrieve, 2), dtype=int)], axis=0)
- assert retrieved_indices.shape == (B, num_to_retrieve, 2)
+ assert retrieved_indices.shape == (xbdim, num_to_retrieve, 2)
- myprint(f'Returning')
return retrieved_indices
\ No newline at end of file
diff --git a/scripts_regent/eval_RandP.py b/scripts_regent/eval_RandP.py
index 07e545c..146b347 100755
--- a/scripts_regent/eval_RandP.py
+++ b/scripts_regent/eval_RandP.py
@@ -15,9 +15,10 @@ from transformers import AutoModelForCausalLM, AutoProcessor, HfArgumentParser
from jat.eval.rl import TASK_NAME_TO_ENV_ID, make
from jat.utils import normalize, push_to_hub, save_video_grid
-from jat_regent.RandP import RandP
+from jat_regent.modeling_RandP import RandP
from datasets import load_from_disk
from datasets.config import HF_DATASETS_CACHE
+from jat_regent.utils import myprint
@dataclass
@@ -70,6 +71,7 @@ def eval_rl(model, processor, task, eval_args):
scores = []
frames = []
for episode in tqdm(range(eval_args.num_episodes), desc=task, unit="episode", leave=False):
+ myprint(('-'*100) + f'{episode=}')
observation, _ = env.reset()
reward = None
rewards = []
@@ -96,6 +98,7 @@ def eval_rl(model, processor, task, eval_args):
frames.append(np.array(env.render(), dtype=np.uint8))
scores.append(sum(rewards))
+ myprint(('-'*100) + '\n\n\n')
env.close()
raw_mean, raw_std = np.mean(scores), np.std(scores)
@@ -145,7 +148,9 @@ def main():
tasks.extend([env_id for env_id in TASK_NAME_TO_ENV_ID.keys() if env_id.startswith(domain)])
device = torch.device("cpu") if eval_args.use_cpu else get_default_device()
- processor = None
+ processor = AutoProcessor.from_pretrained(
+ 'jat-project/jat', cache_dir=None, trust_remote_code=True
+ )
evaluations = {}
video_list = []
@@ -153,14 +158,18 @@ def main():
for task in tqdm(tasks, desc="Evaluation", unit="task", leave=True):
if task in TASK_NAME_TO_ENV_ID.keys():
+ myprint(('-'*100) + f'{task=}')
dataset = load_from_disk(f'{HF_DATASETS_CACHE}/jat-project/jat-dataset-tokenized/{task}')
- model = RandP(dataset)
+ model = RandP(task,
+ dataset,
+ device,)
scores, frames, fps = eval_rl(model, processor, task, eval_args)
evaluations[task] = scores
# Save the video
if eval_args.save_video:
video_list.append(frames)
input_fps.append(fps)
+ myprint(('-'*100) + '\n\n\n')
else:
warnings.warn(f"Task {task} is not supported.")
diff --git a/scripts_regent/offline_retrieval_jat_regent.py b/scripts_regent/offline_retrieval_jat_regent.py
index c83d259..aad678a 100644
--- a/scripts_regent/offline_retrieval_jat_regent.py
+++ b/scripts_regent/offline_retrieval_jat_regent.py
@@ -8,7 +8,7 @@ import time
from datetime import datetime
from datasets import load_from_disk
from datasets.config import HF_DATASETS_CACHE
-from jat_regent.utils import myprint, process_row_atari, process_row_vector, retrieve_atari, retrieve_vector
+from jat_regent.utils import myprint, get_task_info, collect_all_data, process_row_of_obs_atari_full_without_mask, retrieve_atari, retrieve_vector, collect_subset, build_index_vector
import logging
logging.basicConfig(level=logging.DEBUG)
@@ -17,7 +17,8 @@ def main():
parser = argparse.ArgumentParser(description='Build RAAGENT sequence indices')
parser.add_argument('--task', type=str, default='atari-alien', help='Task name')
parser.add_argument('--num_to_retrieve', type=int, default=100, help='Number of states/windows to retrieve')
- parser.add_argument('--nb_cores_autofaiss', type=int, default=16, help='Number of cores to use for faiss in vector observation environments')
+ parser.add_argument('--nb_cores_autofaiss', type=int, default=16, help='Number of cores to use for faiss in vector obs envs')
+ parser.add_argument('--batch_size_retrieval', type=int, default=1024, help='Batch size for retrieval in atari')
args = parser.parse_args()
# load dataset, map, device, for task
@@ -25,77 +26,83 @@ def main():
dataset_path = f"{HF_DATASETS_CACHE}/jat-project/jat-dataset-tokenized/{task}"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- rew_key = 'rewards'
- attn_key = 'attention_mask'
- if task.startswith("atari"):
- obs_key = 'image_observations'
- act_key = 'discrete_actions'
- len_row_tokenized_known = 32 # half of 54
- process_row_fn = process_row_atari
- retrieve_fn = retrieve_atari
- elif task.startswith("babyai"):
- obs_key = 'discrete_observations'# also has 'text_observations' only for raw dataset not for tokenized dataset (as it is combined into discrete_observation in tokenized dataset)
- act_key = 'discrete_actions'
- len_row_tokenized_known = 256 # half of 512
- process_row_fn = lambda attn_mask, row_of_obs, task: process_row_vector(attn_mask, row_of_obs, task, return_numpy=True)
- retrieve_fn = retrieve_vector
- elif task.startswith("metaworld") or task.startswith("mujoco"):
- obs_key = 'continuous_observations'
- act_key = 'continuous_actions'
- len_row_tokenized_known = 256
- process_row_fn = lambda attn_mask, row_of_obs, task: process_row_vector(attn_mask, row_of_obs, task, return_numpy=True)
- retrieve_fn = retrieve_vector
+ rew_key, attn_key, obs_key, act_key, B, obs_dim = get_task_info(task)
dataset = load_from_disk(dataset_path)
with open(f"{dataset_path}/map_from_rows_to_episodes_for_tokenized.json", 'r') as f:
map_from_rows_to_episodes_for_tokenized = json.load(f)
# setup kwargs
- len_dataset = len(dataset['train'])
- B = len_row_tokenized_known
kwargs = {'B': B,
- 'attn_key':attn_key,
- 'obs_key':obs_key,
- 'device':device,
- 'task':task,
- 'batch_size_retrieval':None,
- 'nb_cores_autofaiss':None if task.startswith("atari") else args.nb_cores_autofaiss,
- }
+ 'obs_dim': obs_dim,
+ 'attn_key': attn_key,
+ 'obs_key': obs_key,
+ 'device': device,
+ 'task': task,
+ 'batch_size_retrieval': args.batch_size_retrieval,
+ 'nb_cores_autofaiss': None if task.startswith("atari") else args.nb_cores_autofaiss,
+ }
# collect all observations in a single array (this takes some time) for vector observation environments
- if not task.startswith("atari"):
- myprint("Collecting all observations/attn_masks in a single array")
- all_rows_of_obs = np.array(dataset['train'][obs_key])
- all_attn_masks = np.array(dataset['train'][attn_key]).astype(bool)
+ myprint("Collecting all observations/attn_masks")
+ all_rows_of_obs_OG, all_attn_masks_OG, all_row_idxs = collect_all_data(dataset, task, obs_key)
# iterate over rows
all_retrieved_indices = []
- for row_idx in range(len_dataset):
- myprint(f"\nProcessing row {row_idx}/{len_dataset}")
+ for row_idx in all_row_idxs:
+ myprint(f"\nProcessing row {row_idx}/{len(all_row_idxs)}")
current_ep = map_from_rows_to_episodes_for_tokenized[str(row_idx)]
- attn_mask, row_of_obs = process_row_fn(dataset['train'][row_idx][attn_key], dataset['train'][row_idx][obs_key], task)
+ # get row_of_obs and attn_mask
+ datarow = dataset['train'][row_idx]
+ attn_mask = np.array(datarow[attn_key]).astype(bool)
+ if task.startswith("atari"):
+ row_of_obs = process_row_of_obs_atari_full_without_mask(datarow[obs_key])
+ else:
+ row_of_obs = np.array(datarow[obs_key])
+ row_of_obs = row_of_obs[attn_mask]
+ assert row_of_obs.shape == (np.sum(attn_mask), *obs_dim)
# compare with rows from all but the current episode
- all_other_rows = [idx for idx in range(len_dataset) if map_from_rows_to_episodes_for_tokenized[str(idx)] != current_ep]
+ all_other_row_idxs = [idx for idx in all_row_idxs if map_from_rows_to_episodes_for_tokenized[str(idx)] != current_ep]
# do the retrieval
- retrieved_indices = retrieve_fn(row_of_obs=row_of_obs,
- dataset=dataset if task.startswith("atari") else (all_rows_of_obs, all_attn_masks),
- all_rows_to_consider=all_other_rows,
- num_to_retrieve=args.num_to_retrieve,
- kwargs=kwargs,
- )
+ if task.startswith("atari"):
+ all_indices, all_processed_rows_of_obs = collect_subset(all_rows_of_obs_OG=all_rows_of_obs_OG,
+ all_attn_masks_OG=all_attn_masks_OG,
+ all_rows_to_consider=all_other_row_idxs,
+ kwargs=kwargs)
+ retrieved_indices = retrieve_atari(row_of_obs=row_of_obs,
+ all_processed_rows_of_obs=all_processed_rows_of_obs,
+ all_indices=all_indices,
+ num_to_retrieve=args.num_to_retrieve,
+ kwargs=kwargs)
+ else:
+ all_indices, knn_index = build_index_vector(all_rows_of_obs_OG=all_rows_of_obs_OG,
+ all_attn_masks_OG=all_attn_masks_OG,
+ all_rows_to_consider=all_other_row_idxs,
+ kwargs=kwargs)
+ retrieved_indices = retrieve_vector(row_of_obs=row_of_obs,
+ knn_index=knn_index,
+ all_indices=all_indices,
+ num_to_retrieve=args.num_to_retrieve,
+ kwargs=kwargs)
+
+ # pad the above to expected B
+ xbdim = row_of_obs.shape[0]
+ if xbdim < B:
+ retrieved_indices = np.concatenate([retrieved_indices, np.zeros((B-xbdim, args.num_to_retrieve, 2), dtype=int)], axis=0)
+ assert retrieved_indices.shape == (B, args.num_to_retrieve, 2)
# collect retrieved indices
all_retrieved_indices.append(retrieved_indices)
# concat
all_retrieved_indices = np.stack(all_retrieved_indices, axis=0)
- assert all_retrieved_indices.shape == (len_dataset, B, args.num_to_retrieve, 2)
+ assert all_retrieved_indices.shape == (len(all_row_idxs), B, args.num_to_retrieve, 2)
# save arrays as bin for easy memmap access and faster loading
- all_retrieved_indices.tofile(f"{dataset_path}/retrieved_indices_{len_dataset}_{B}_{args.num_to_retrieve}_2.bin")
+ all_retrieved_indices.tofile(f"{dataset_path}/retrieved_indices_{len(all_row_idxs)}_{B}_{args.num_to_retrieve}_2.bin")
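+ # to load later without reading everything into RAM (assumes the default int dtype above, int64 on most platforms):
+ # np.memmap(path, dtype=np.int64, mode='r', shape=(len(all_row_idxs), B, args.num_to_retrieve, 2))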
if __name__ == "__main__":
main()
\ No newline at end of file