|
|
|
|
|
|
|
|
|
@@ -21,6 +21,21 @@ conda activate jat |
|
pip install -e .[dev] |
|
``` |
|
|
|
+## REGENT fork of sample-factory: Installation |
|
+Following [this install ink](https://www.samplefactory.dev/01-get-started/installation/) but for the fork: |
|
+```shell |
|
+git clone https://github.com/kaustubhsridhar/sample-factory.git |
|
+cd sample-factory |
|
+pip install -e .[dev,mujoco,atari,envpool,vizdoom] |
|
+``` |
|
+ |
|
+# Regent fork of sample-factory: Train Unseen Env Policies and Generate Datasets |
|
+Train policies using envpool's atari: |
|
+```shell |
|
+bash scripts_sample-factory/train_unseen_atari.sh |
|
+``` |
|
+Note that the training command inside the above script was obtained from the config files of Ed Beeching's Atari 57 models on Huggingface. An example is [here](https://huggingface.co/edbeeching/atari_2B_atari_mspacman_1111/blob/main/cfg.json#L124). See my discussion [here](https://huggingface.co/edbeeching/atari_2B_atari_mspacman_1111/discussions/2). |
|
+ |
|
## PREV Installation |
|
|
|
To get started with JAT, follow these steps: |
|
@@ -155,12 +170,21 @@ python -u scripts_jat_regent/eval_RandP.py --task ${TASK} &> outputs/RandP/${TAS |
|
``` |
|
|
|
### REGENT Analyze data |
|
+Necessary: |
|
```shell |
|
-python -u examples_regent/compare_datasets.py &> examples_regent/compare_datasets.txt & |
|
- |
|
python -u examples_regent/analyze_rows_tokenized.py &> examples_regent/analyze_rows_tokenized.txt & |
|
+``` |
|
|
|
+Already ran and output dict in code: |
|
+```shell |
|
python -u examples_regent/get_dim_all_vector_tasks.py &> examples_regent/get_dim_all_vector_tasks.txt & |
|
+ |
|
+python -u examples_regent/count_rows_to_consider.py &> examples_regent/count_rows_to_consider.txt & |
|
+``` |
|
+ |
|
+Optional: |
|
+```shell |
|
+python -u examples_regent/compare_datasets.py &> examples_regent/compare_datasets.txt & |
|
``` |
|
|
|
## PREV Dataset |
|
|
|
deleted file mode 100644 |
|
|
|
|
|
|
|
@@ -1,38 +0,0 @@ |
|
-import warnings |
|
-from dataclasses import dataclass |
|
-from typing import List, Optional, Tuple, Union |
|
- |
|
-import numpy as np |
|
-import torch |
|
-import torch.nn.functional as F |
|
-from gymnasium import spaces |
|
-from torch import BoolTensor, FloatTensor, LongTensor, Tensor, nn |
|
-from transformers import GPTNeoModel, GPTNeoPreTrainedModel |
|
-from transformers.modeling_outputs import ModelOutput |
|
-from transformers.models.vit.modeling_vit import ViTPatchEmbeddings |
|
- |
|
-from jat.configuration_jat import JatConfig |
|
-from jat.processing_jat import JatProcessor |
|
- |
|
- |
|
-class RandP(): |
|
- def __init__(self, dataset) -> None: |
|
- self.steps = 0 |
|
- # create an index for retrieval in vector obs envs (OR) collect all images in Atari |
|
- |
|
- def reset_rl(self): |
|
- self.steps = 0 |
|
- |
|
- def get_next_action( |
|
- self, |
|
- processor: JatProcessor, |
|
- continuous_observation: Optional[List[float]] = None, |
|
- discrete_observation: Optional[List[int]] = None, |
|
- text_observation: Optional[str] = None, |
|
- image_observation: Optional[np.ndarray] = None, |
|
- action_space: Union[spaces.Box, spaces.Discrete] = None, |
|
- reward: Optional[float] = None, |
|
- deterministic: bool = False, |
|
- context_window: Optional[int] = None, |
|
- ): |
|
- pass |
|
\ No newline at end of file |
|
|
|
deleted file mode 100644 |
|
|
|
|
|
|
|
|
|
|
|
@@ -8,23 +8,35 @@ from tqdm import tqdm |
|
from autofaiss import build_index |
|
|
|
|
|
+UNSEEN_TASK_NAMES = { # Total -- atari: 57, metaworld: 50, babyai: 39, mujoco: 11 |
|
+ |
|
+} |
|
+ |
|
def myprint(str): |
|
- # check if first character of string is a newline character |
|
- if str[0] == '\n': |
|
- str_without_newline = str[1:] |
|
+ # check if first characters of string are newline character |
|
+ num_newlines = 0 |
|
+ while str[num_newlines] == '\n': |
|
print() |
|
- else: |
|
- str_without_newline = str |
|
+ num_newlines += 1 |
|
+ str_without_newline = str[num_newlines:] |
|
print(f'{datetime.now().strftime("%Y-%m-%d %H:%M:%S")}: {str_without_newline}') |
|
|
|
def is_png_img(item): |
|
return isinstance(item, PngImagePlugin.PngImageFile) |
|
|
|
+def get_last_row_for_1M_states(task): |
|
+ last_row_idx = {'atari-alien': 14134, 'atari-amidar': 14319, 'atari-assault': 14427, 'atari-asterix': 14456, 'atari-asteroids': 14348, 'atari-atlantis': 14325, 'atari-bankheist': 14167, 'atari-battlezone': 13981, 'atari-beamrider': 13442, 'atari-berzerk': 13534, 'atari-bowling': 14110, 'atari-boxing': 14542, 'atari-breakout': 13474, 'atari-centipede': 14196, 'atari-choppercommand': 13397, 'atari-crazyclimber': 14026, 'atari-defender': 13504, 'atari-demonattack': 13499, 'atari-doubledunk': 14292, 'atari-enduro': 13260, 'atari-fishingderby': 14073, 'atari-freeway': 14016, 'atari-frostbite': 14075, 'atari-gopher': 13143, 'atari-gravitar': 14405, 'atari-hero': 14044, 'atari-icehockey': 14017, 'atari-jamesbond': 12678, 'atari-kangaroo': 14248, 'atari-krull': 14204, 'atari-kungfumaster': 14030, 'atari-montezumarevenge': 14219, 'atari-mspacman': 14120, 'atari-namethisgame': 13575, 'atari-phoenix': 13539, 'atari-pitfall': 14287, 'atari-pong': 14151, 'atari-privateeye': 14105, 'atari-qbert': 14026, 'atari-riverraid': 14275, 'atari-roadrunner': 14127, 'atari-robotank': 14079, 'atari-seaquest': 14097, 'atari-skiing': 14708, 'atari-solaris': 14199, 'atari-spaceinvaders': 12652, 'atari-stargunner': 13822, 'atari-surround': 13840, 'atari-tennis': 14062, 'atari-timepilot': 13896, 'atari-tutankham': 13121, 'atari-upndown': 13504, 'atari-venture': 14260, 'atari-videopinball': 14272, 'atari-wizardofwor': 13920, 'atari-yarsrevenge': 13981, 'atari-zaxxon': 13833, 'babyai-action-obj-door': 95000, 'babyai-blocked-unlock-pickup': 29279, 'babyai-boss-level-no-unlock': 12087, 'babyai-boss-level': 12101, 'babyai-find-obj-s5': 32974, 'babyai-go-to-door': 95000, 'babyai-go-to-imp-unlock': 9286, 'babyai-go-to-local': 95000, 'babyai-go-to-obj-door': 95000, 'babyai-go-to-obj': 95000, 'babyai-go-to-red-ball-grey': 95000, 'babyai-go-to-red-ball-no-dists': 95000, 'babyai-go-to-red-ball': 95000, 'babyai-go-to-red-blue-ball': 95000, 'babyai-go-to-seq': 13744, 'babyai-go-to': 18974, 'babyai-key-corridor': 9014, 'babyai-mini-boss-level': 38119, 'babyai-move-two-across-s8n9': 24505, 'babyai-one-room-s8': 95000, 'babyai-open-door': 95000, 'babyai-open-doors-order-n4': 95000, 'babyai-open-red-door': 95000, 'babyai-open-two-doors': 73291, 'babyai-open': 32559, 'babyai-pickup-above': 34084, 'babyai-pickup-dist': 89640, 'babyai-pickup-loc': 95000, 'babyai-pickup': 18670, 'babyai-put-next-local': 83187, 'babyai-put-next': 56986, 'babyai-synth-loc': 21605, 'babyai-synth-seq': 13049, 'babyai-synth': 19409, 'babyai-unblock-pickup': 17881, 'babyai-unlock-local': 71186, 'babyai-unlock-pickup': 50883, 'babyai-unlock-to-unlock': 23062, 'babyai-unlock': 11734, 'metaworld-assembly': 10000, 'metaworld-basketball': 10000, 'metaworld-bin-picking': 10000, 'metaworld-box-close': 10000, 'metaworld-button-press-topdown-wall': 10000, 'metaworld-button-press-topdown': 10000, 'metaworld-button-press-wall': 10000, 'metaworld-button-press': 10000, 'metaworld-coffee-button': 10000, 'metaworld-coffee-pull': 10000, 'metaworld-coffee-push': 10000, 'metaworld-dial-turn': 10000, 'metaworld-disassemble': 10000, 'metaworld-door-close': 10000, 'metaworld-door-lock': 10000, 'metaworld-door-open': 10000, 'metaworld-door-unlock': 10000, 'metaworld-drawer-close': 10000, 'metaworld-drawer-open': 10000, 'metaworld-faucet-close': 10000, 'metaworld-faucet-open': 10000, 'metaworld-hammer': 10000, 'metaworld-hand-insert': 10000, 'metaworld-handle-press-side': 10000, 'metaworld-handle-press': 10000, 'metaworld-handle-pull-side': 10000, 'metaworld-handle-pull': 10000, 'metaworld-lever-pull': 10000, 'metaworld-peg-insert-side': 10000, 'metaworld-peg-unplug-side': 10000, 'metaworld-pick-out-of-hole': 10000, 'metaworld-pick-place-wall': 10000, 'metaworld-pick-place': 10000, 'metaworld-plate-slide-back-side': 10000, 'metaworld-plate-slide-back': 10000, 'metaworld-plate-slide-side': 10000, 'metaworld-plate-slide': 10000, 'metaworld-push-back': 10000, 'metaworld-push-wall': 10000, 'metaworld-push': 10000, 'metaworld-reach-wall': 10000, 'metaworld-reach': 10000, 'metaworld-shelf-place': 10000, 'metaworld-soccer': 10000, 'metaworld-stick-pull': 10000, 'metaworld-stick-push': 10000, 'metaworld-sweep-into': 10000, 'metaworld-sweep': 10000, 'metaworld-window-close': 10000, 'metaworld-window-open': 10000, 'mujoco-ant': 4023, 'mujoco-doublependulum': 4002, 'mujoco-halfcheetah': 4000, 'mujoco-hopper': 4931, 'mujoco-humanoid': 4119, 'mujoco-pendulum': 4959, 'mujoco-pusher': 9000, 'mujoco-reacher': 9000, 'mujoco-standup': 4000, 'mujoco-swimmer': 4000, 'mujoco-walker': 4101} |
|
+ return last_row_idx[task] |
|
+ |
|
+def get_last_row_for_100k_states(task): |
|
+ last_row_idx = {'atari-alien': 3135, 'atari-amidar': 3142, 'atari-assault': 3132, 'atari-asterix': 3181, 'atari-asteroids': 3127, 'atari-atlantis': 3128, 'atari-bankheist': 3156, 'atari-battlezone': 3136, 'atari-beamrider': 3131, 'atari-berzerk': 3127, 'atari-bowling': 3148, 'atari-boxing': 3227, 'atari-breakout': 3128, 'atari-centipede': 3176, 'atari-choppercommand': 3144, 'atari-crazyclimber': 3134, 'atari-defender': 3127, 'atari-demonattack': 3127, 'atari-doubledunk': 3175, 'atari-enduro': 3126, 'atari-fishingderby': 3155, 'atari-freeway': 3131, 'atari-frostbite': 3146, 'atari-gopher': 3128, 'atari-gravitar': 3202, 'atari-hero': 3144, 'atari-icehockey': 3138, 'atari-jamesbond': 3131, 'atari-kangaroo': 3160, 'atari-krull': 3162, 'atari-kungfumaster': 3143, 'atari-montezumarevenge': 3168, 'atari-mspacman': 3143, 'atari-namethisgame': 3131, 'atari-phoenix': 3127, 'atari-pitfall': 3131, 'atari-pong': 3160, 'atari-privateeye': 3158, 'atari-qbert': 3136, 'atari-riverraid': 3157, 'atari-roadrunner': 3150, 'atari-robotank': 3133, 'atari-seaquest': 3138, 'atari-skiing': 3271, 'atari-solaris': 3129, 'atari-spaceinvaders': 3128, 'atari-stargunner': 3129, 'atari-surround': 3143, 'atari-tennis': 3129, 'atari-timepilot': 3132, 'atari-tutankham': 3127, 'atari-upndown': 3127, 'atari-venture': 3148, 'atari-videopinball': 3130, 'atari-wizardofwor': 3138, 'atari-yarsrevenge': 3129, 'atari-zaxxon': 3133, 'babyai-action-obj-door': 15923, 'babyai-blocked-unlock-pickup': 2919, 'babyai-boss-level-no-unlock': 1217, 'babyai-boss-level': 1159, 'babyai-find-obj-s5': 3345, 'babyai-go-to-door': 18875, 'babyai-go-to-imp-unlock': 923, 'babyai-go-to-local': 18724, 'babyai-go-to-obj-door': 16472, 'babyai-go-to-obj': 20197, 'babyai-go-to-red-ball-grey': 16953, 'babyai-go-to-red-ball-no-dists': 20165, 'babyai-go-to-red-ball': 18730, 'babyai-go-to-red-blue-ball': 16934, 'babyai-go-to-seq': 1439, 'babyai-go-to': 1964, 'babyai-key-corridor': 900, 'babyai-mini-boss-level': 3789, 'babyai-move-two-across-s8n9': 2462, 'babyai-one-room-s8': 16994, 'babyai-open-door': 13565, 'babyai-open-doors-order-n4': 9706, 'babyai-open-red-door': 21185, 'babyai-open-two-doors': 7348, 'babyai-open': 3331, 'babyai-pickup-above': 3392, 'babyai-pickup-dist': 19693, 'babyai-pickup-loc': 16405, 'babyai-pickup': 1806, 'babyai-put-next-local': 8303, 'babyai-put-next': 5703, 'babyai-synth-loc': 2183, 'babyai-synth-seq': 1316, 'babyai-synth': 1964, 'babyai-unblock-pickup': 1886, 'babyai-unlock-local': 7118, 'babyai-unlock-pickup': 5107, 'babyai-unlock-to-unlock': 2309, 'babyai-unlock': 1177, 'metaworld-assembly': 1000, 'metaworld-basketball': 1000, 'metaworld-bin-picking': 1000, 'metaworld-box-close': 1000, 'metaworld-button-press-topdown-wall': 1000, 'metaworld-button-press-topdown': 1000, 'metaworld-button-press-wall': 1000, 'metaworld-button-press': 1000, 'metaworld-coffee-button': 1000, 'metaworld-coffee-pull': 1000, 'metaworld-coffee-push': 1000, 'metaworld-dial-turn': 1000, 'metaworld-disassemble': 1000, 'metaworld-door-close': 1000, 'metaworld-door-lock': 1000, 'metaworld-door-open': 1000, 'metaworld-door-unlock': 1000, 'metaworld-drawer-close': 1000, 'metaworld-drawer-open': 1000, 'metaworld-faucet-close': 1000, 'metaworld-faucet-open': 1000, 'metaworld-hammer': 1000, 'metaworld-hand-insert': 1000, 'metaworld-handle-press-side': 1000, 'metaworld-handle-press': 1000, 'metaworld-handle-pull-side': 1000, 'metaworld-handle-pull': 1000, 'metaworld-lever-pull': 1000, 'metaworld-peg-insert-side': 1000, 'metaworld-peg-unplug-side': 1000, 'metaworld-pick-out-of-hole': 1000, 'metaworld-pick-place-wall': 1000, 'metaworld-pick-place': 1000, 'metaworld-plate-slide-back-side': 1000, 'metaworld-plate-slide-back': 1000, 'metaworld-plate-slide-side': 1000, 'metaworld-plate-slide': 1000, 'metaworld-push-back': 1000, 'metaworld-push-wall': 1000, 'metaworld-push': 1000, 'metaworld-reach-wall': 1000, 'metaworld-reach': 1000, 'metaworld-shelf-place': 1000, 'metaworld-soccer': 1000, 'metaworld-stick-pull': 1000, 'metaworld-stick-push': 1000, 'metaworld-sweep-into': 1000, 'metaworld-sweep': 1000, 'metaworld-window-close': 1000, 'metaworld-window-open': 1000, 'mujoco-ant': 401, 'mujoco-doublependulum': 401, 'mujoco-halfcheetah': 400, 'mujoco-hopper': 491, 'mujoco-humanoid': 415, 'mujoco-pendulum': 495, 'mujoco-pusher': 1000, 'mujoco-reacher': 2000, 'mujoco-standup': 400, 'mujoco-swimmer': 400, 'mujoco-walker': 407} |
|
+ return last_row_idx[task] |
|
+ |
|
def get_obs_dim(task): |
|
assert task.startswith("babyai") or task.startswith("metaworld") or task.startswith("mujoco") |
|
|
|
all_obs_dims={'babyai-action-obj-door': 212, 'babyai-blocked-unlock-pickup': 212, 'babyai-boss-level-no-unlock': 212, 'babyai-boss-level': 212, 'babyai-find-obj-s5': 212, 'babyai-go-to-door': 212, 'babyai-go-to-imp-unlock': 212, 'babyai-go-to-local': 212, 'babyai-go-to-obj-door': 212, 'babyai-go-to-obj': 212, 'babyai-go-to-red-ball-grey': 212, 'babyai-go-to-red-ball-no-dists': 212, 'babyai-go-to-red-ball': 212, 'babyai-go-to-red-blue-ball': 212, 'babyai-go-to-seq': 212, 'babyai-go-to': 212, 'babyai-key-corridor': 212, 'babyai-mini-boss-level': 212, 'babyai-move-two-across-s8n9': 212, 'babyai-one-room-s8': 212, 'babyai-open-door': 212, 'babyai-open-doors-order-n4': 212, 'babyai-open-red-door': 212, 'babyai-open-two-doors': 212, 'babyai-open': 212, 'babyai-pickup-above': 212, 'babyai-pickup-dist': 212, 'babyai-pickup-loc': 212, 'babyai-pickup': 212, 'babyai-put-next-local': 212, 'babyai-put-next': 212, 'babyai-synth-loc': 212, 'babyai-synth-seq': 212, 'babyai-synth': 212, 'babyai-unblock-pickup': 212, 'babyai-unlock-local': 212, 'babyai-unlock-pickup': 212, 'babyai-unlock-to-unlock': 212, 'babyai-unlock': 212, 'metaworld-assembly': 39, 'metaworld-basketball': 39, 'metaworld-bin-picking': 39, 'metaworld-box-close': 39, 'metaworld-button-press-topdown-wall': 39, 'metaworld-button-press-topdown': 39, 'metaworld-button-press-wall': 39, 'metaworld-button-press': 39, 'metaworld-coffee-button': 39, 'metaworld-coffee-pull': 39, 'metaworld-coffee-push': 39, 'metaworld-dial-turn': 39, 'metaworld-disassemble': 39, 'metaworld-door-close': 39, 'metaworld-door-lock': 39, 'metaworld-door-open': 39, 'metaworld-door-unlock': 39, 'metaworld-drawer-close': 39, 'metaworld-drawer-open': 39, 'metaworld-faucet-close': 39, 'metaworld-faucet-open': 39, 'metaworld-hammer': 39, 'metaworld-hand-insert': 39, 'metaworld-handle-press-side': 39, 'metaworld-handle-press': 39, 'metaworld-handle-pull-side': 39, 'metaworld-handle-pull': 39, 'metaworld-lever-pull': 39, 'metaworld-peg-insert-side': 39, 'metaworld-peg-unplug-side': 39, 'metaworld-pick-out-of-hole': 39, 'metaworld-pick-place-wall': 39, 'metaworld-pick-place': 39, 'metaworld-plate-slide-back-side': 39, 'metaworld-plate-slide-back': 39, 'metaworld-plate-slide-side': 39, 'metaworld-plate-slide': 39, 'metaworld-push-back': 39, 'metaworld-push-wall': 39, 'metaworld-push': 39, 'metaworld-reach-wall': 39, 'metaworld-reach': 39, 'metaworld-shelf-place': 39, 'metaworld-soccer': 39, 'metaworld-stick-pull': 39, 'metaworld-stick-push': 39, 'metaworld-sweep-into': 39, 'metaworld-sweep': 39, 'metaworld-window-close': 39, 'metaworld-window-open': 39, 'mujoco-ant': 27, 'mujoco-doublependulum': 11, 'mujoco-halfcheetah': 17, 'mujoco-hopper': 11, 'mujoco-humanoid': 376, 'mujoco-pendulum': 4, 'mujoco-pusher': 23, 'mujoco-reacher': 11, 'mujoco-standup': 376, 'mujoco-swimmer': 8, 'mujoco-walker': 17} |
|
- return all_obs_dims[task] |
|
+ return (all_obs_dims[task],) |
|
|
|
def get_act_dim(task): |
|
assert task.startswith("babyai") or task.startswith("metaworld") or task.startswith("mujoco") |
|
@@ -36,141 +48,188 @@ def get_act_dim(task): |
|
elif task.startswith("mujoco"): |
|
all_act_dims={'mujoco-ant': 8, 'mujoco-doublependulum': 1, 'mujoco-halfcheetah': 6, 'mujoco-hopper': 3, 'mujoco-humanoid': 17, 'mujoco-pendulum': 1, 'mujoco-pusher': 7, 'mujoco-reacher': 2, 'mujoco-standup': 17, 'mujoco-swimmer': 2, 'mujoco-walker': 6} |
|
return all_act_dims[task] |
|
- |
|
-def process_row_atari(attn_mask, row_of_obs, task): |
|
- """ |
|
- Example for selection with bools: |
|
- >>> a = np.array([0,1,2,3,4,5]) |
|
- >>> b = np.array([1,0,0,0,0,1]).astype(bool) |
|
- >>> a[b] |
|
- array([0, 5]) |
|
- """ |
|
- attn_mask = np.array(attn_mask).astype(bool) |
|
|
|
- row_of_obs = torch.stack([to_tensor(np.array(img)) for img in row_of_obs]) |
|
- row_of_obs = row_of_obs[attn_mask] |
|
+def get_task_info(task): |
|
+ rew_key = 'rewards' |
|
+ attn_key = 'attention_mask' |
|
+ if task.startswith("atari"): |
|
+ obs_key = 'image_observations' |
|
+ act_key = 'discrete_actions' |
|
+ B = 32 # half of 54 |
|
+ obs_dim = (3, 4*84, 84) |
|
+ elif task.startswith("babyai"): |
|
+ obs_key = 'discrete_observations' # also has 'text_observations' only for raw dataset not for tokenized dataset (as it is combined into discrete_observation in tokenized dataset) |
|
+ act_key = 'discrete_actions' |
|
+ B = 256 # half of 512 |
|
+ obs_dim = get_obs_dim(task) |
|
+ elif task.startswith("metaworld") or task.startswith("mujoco"): |
|
+ obs_key = 'continuous_observations' |
|
+ act_key = 'continuous_actions' |
|
+ B = 256 |
|
+ obs_dim = get_obs_dim(task) |
|
+ |
|
+ return rew_key, attn_key, obs_key, act_key, B, obs_dim |
|
+ |
|
+def process_row_of_obs_atari_full_without_mask(row_of_obs): |
|
+ |
|
+ if not isinstance(row_of_obs, torch.Tensor): |
|
+ row_of_obs = torch.stack([to_tensor(np.array(img)) for img in row_of_obs]) |
|
row_of_obs = row_of_obs * 0.5 + 0.5 # denormalize from [-1, 1] to [0, 1] |
|
- assert row_of_obs.shape == (sum(attn_mask), 84, 4, 84) |
|
+ assert row_of_obs.shape == (len(row_of_obs), 84, 4, 84) |
|
row_of_obs = row_of_obs.permute(0, 2, 1, 3) # (*, 4, 84, 84) |
|
- row_of_obs = row_of_obs.reshape(sum(attn_mask), 4*84, 84) # put side-by-side |
|
+ row_of_obs = row_of_obs.reshape(len(row_of_obs), 4*84, 84) # put side-by-side |
|
row_of_obs = row_of_obs.unsqueeze(1).repeat(1, 3, 1, 1) # repeat for 3 channels |
|
- assert row_of_obs.shape == (sum(attn_mask), 3, 4*84, 84) # sum(attn_mask) is the batch size dimension |
|
- |
|
- return attn_mask, row_of_obs |
|
+ assert row_of_obs.shape == (len(row_of_obs), 3, 4*84, 84) # sum(attn_mask) is the batch size dimension |
|
+ |
|
+ return row_of_obs |
|
|
|
-def process_row_vector(attn_mask, row_of_obs, task, return_numpy=False): |
|
- attn_mask = np.array(attn_mask).astype(bool) |
|
+def collect_all_atari_data(dataset, all_row_idxs=None): |
|
+ if all_row_idxs is None: |
|
+ all_row_idxs = list(range(len(dataset['train']))) |
|
|
|
- row_of_obs = np.array(row_of_obs) |
|
- if not return_numpy: |
|
- row_of_obs = torch.tensor(row_of_obs) |
|
- row_of_obs = row_of_obs[attn_mask] |
|
- assert row_of_obs.shape == (sum(attn_mask), get_obs_dim(task)) |
|
- |
|
- return attn_mask, row_of_obs |
|
- |
|
-def retrieve_atari(row_of_obs, # query: (row_B, 3, 4*84, 84) |
|
- dataset, # to retrieve from |
|
- all_rows_to_consider, # rows to consider |
|
- num_to_retrieve, # top-k |
|
+ all_rows_of_obs = [] |
|
+ all_attn_masks = [] |
|
+ for row_idx in tqdm(all_row_idxs): |
|
+ datarow = dataset['train'][row_idx] |
|
+ row_of_obs = process_row_of_obs_atari_full_without_mask(datarow['image_observations']) |
|
+ attn_mask = np.array(datarow['attention_mask']).astype(bool) |
|
+ all_rows_of_obs.append(row_of_obs) # appending tensor |
|
+ all_attn_masks.append(attn_mask) # appending np array |
|
+ all_rows_of_obs = torch.stack(all_rows_of_obs, dim=0) # stacking tensors |
|
+ all_attn_masks = np.stack(all_attn_masks, axis=0) # concatenating np arrays |
|
+ assert (all_rows_of_obs.shape == (len(all_row_idxs), 32, 3, 4*84, 84) and |
|
+ all_attn_masks.shape == (len(all_row_idxs), 32)) |
|
+ return all_attn_masks, all_rows_of_obs |
|
+ |
|
+def collect_all_data(dataset, task, obs_key): |
|
+ last_row_idx = get_last_row_for_100k_states(task) |
|
+ all_row_idxs = list(range(last_row_idx)) |
|
+ if task.startswith("atari"): |
|
+ myprint("Collecting all Atari images and Atari attention masks...") |
|
+ all_attn_masks_OG, all_rows_of_obs_OG = collect_all_atari_data(dataset, all_row_idxs) |
|
+ else: |
|
+ datarows = dataset['train'][all_row_idxs] |
|
+ all_rows_of_obs_OG = np.array(datarows[obs_key]) |
|
+ all_attn_masks_OG = np.array(datarows['attention_mask']).astype(bool) |
|
+ return all_rows_of_obs_OG, all_attn_masks_OG, all_row_idxs |
|
+ |
|
+def collect_subset(all_rows_of_obs_OG, |
|
+ all_attn_masks_OG, |
|
+ all_rows_to_consider, |
|
+ kwargs |
|
+ ): |
|
+ """ |
|
+ Function to collect subset of data given all_rows_to_consider, reshape it, create all_indices and return. |
|
+ Used in both retrieve_atari() and retrieve_vector() --> build_index_vector(). |
|
+ """ |
|
+ myprint(f'\n\n\n' + ('-'*100) + f'Collecting subset...') |
|
+ # read kwargs |
|
+ B, task, obs_dim = kwargs['B'], kwargs['task'], kwargs['obs_dim'] |
|
+ |
|
+ # take subset based on all_rows_to_consider |
|
+ myprint(f'Taking subset of data based on all_rows_to_consider...') |
|
+ all_processed_rows_of_obs = all_rows_of_obs_OG[all_rows_to_consider] |
|
+ all_attn_masks = all_attn_masks_OG[all_rows_to_consider] |
|
+ assert (all_processed_rows_of_obs.shape == (len(all_rows_to_consider), B, *obs_dim) and |
|
+ all_attn_masks.shape == (len(all_rows_to_consider), B)) |
|
+ |
|
+ # reshape |
|
+ myprint(f'Reshaping data...') |
|
+ all_attn_masks = all_attn_masks.reshape(-1) |
|
+ all_processed_rows_of_obs = all_processed_rows_of_obs.reshape(-1, *obs_dim) |
|
+ all_processed_rows_of_obs = all_processed_rows_of_obs[all_attn_masks] |
|
+ assert (all_attn_masks.shape == (len(all_rows_to_consider) * B,) and |
|
+ all_processed_rows_of_obs.shape == (np.sum(all_attn_masks), *obs_dim)) |
|
+ |
|
+ # collect indices of data |
|
+ myprint(f'Collecting indices of data...') |
|
+ all_indices = np.array([[row_idx, i] for row_idx in all_rows_to_consider for i in range(B)]) |
|
+ all_indices = all_indices[all_attn_masks] # this is fine because all attn masks have 0s that only come after 1s |
|
+ assert all_indices.shape == (np.sum(all_attn_masks), 2) |
|
+ |
|
+ myprint(f'{all_indices.shape=}, {all_processed_rows_of_obs.shape=}') |
|
+ myprint(('-'*100) + '\n\n\n') |
|
+ return all_indices, all_processed_rows_of_obs |
|
+ |
|
+def retrieve_atari(row_of_obs, # query: (xbdim, 3, 4*84, 84) / (xdim *obs_dim) |
|
+ all_processed_rows_of_obs, |
|
+ all_indices, |
|
+ num_to_retrieve, |
|
kwargs |
|
- ): |
|
+ ): |
|
+ """ |
|
+ Retrieval for Atari with images, ssim distance, and on GPU. |
|
+ """ |
|
assert isinstance(row_of_obs, torch.Tensor) |
|
|
|
# read kwargs # Note: B = len of row |
|
- B, attn_key, obs_key, device, task, batch_size_retrieval = kwargs['B'], kwargs['attn_key'], kwargs['obs_key'], kwargs['device'], kwargs['task'], kwargs['batch_size_retrieval'] |
|
+ B, device, batch_size_retrieval = kwargs['B'], kwargs['device'], kwargs['batch_size_retrieval'] |
|
|
|
# batch size of row_of_obs which can be <= B since we process before calling this function |
|
- row_B = row_of_obs.shape[0] |
|
- |
|
+ xbdim = row_of_obs.shape[0] |
|
+ |
|
+ # collect subset of data that we can retrieve from |
|
+ ydim = all_processed_rows_of_obs.shape[0] |
|
+ |
|
# first argument for ssim |
|
- repeated_row_og = row_of_obs.repeat_interleave(B, dim=0).to(device) |
|
- assert repeated_row_og.shape == (row_B*B, 3, 4*84, 84) |
|
+ xbatch = row_of_obs.repeat_interleave(batch_size_retrieval, dim=0).to(device) |
|
+ assert xbatch.shape == (xbdim * batch_size_retrieval, 3, 4*84, 84) |
|
|
|
- # iterate over all other rows |
|
+ # iterate over data that we can retrieve from in batches |
|
all_ssim = [] |
|
- all_indices = [] |
|
- total = 0 |
|
- for other_row_idx in tqdm(all_rows_to_consider): |
|
- other_attn_mask, other_row_of_obs = process_row_atari(dataset['train'][other_row_idx][attn_key], dataset['train'][other_row_idx][obs_key]) |
|
- |
|
- # batch size of other_row_of_obs |
|
- other_row_B = other_row_of_obs.shape[0] |
|
- total += other_row_B |
|
- |
|
- # first argument for ssim: RECHECK |
|
- if other_row_B < B: # when other row has less observations than expected |
|
- repeated_row = row_of_obs.repeat_interleave(other_row_B, dim=0).to(device) |
|
- elif other_row_B == B: # otherwise just use the one created before the for loop |
|
- repeated_row = repeated_row_og |
|
- assert repeated_row.shape == (row_B*other_row_B, 3, 4*84, 84) |
|
- |
|
+ for j in range(0, ydim, batch_size_retrieval): |
|
# second argument for ssim |
|
- repeated_other_row = other_row_of_obs.repeat(row_B, 1, 1, 1).to(device) |
|
- assert repeated_other_row.shape == (row_B*other_row_B, 3, 4*84, 84) |
|
+ ybatch = all_processed_rows_of_obs[j:j+batch_size_retrieval] |
|
+ ybdim = ybatch.shape[0] |
|
+ ybatch = ybatch.repeat(xbdim, 1, 1, 1).to(device) |
|
+ assert ybatch.shape == (ybdim * xbdim, 3, 4*84, 84) |
|
+ |
|
+ if ybdim < batch_size_retrieval: # for last batch |
|
+ xbatch = row_of_obs.repeat_interleave(ybdim, dim=0).to(device) |
|
+ assert xbatch.shape == (xbdim * ybdim, 3, 4*84, 84) |
|
|
|
# compare via ssim and updated all_ssim |
|
- ssim_score = ssim(repeated_row, repeated_other_row, data_range=1.0, size_average=False) |
|
- ssim_score = ssim_score.reshape(row_B, other_row_B) |
|
+ ssim_score = ssim(xbatch, ybatch, data_range=1.0, size_average=False) |
|
+ ssim_score = ssim_score.reshape(xbdim, ybdim) |
|
all_ssim.append(ssim_score) |
|
|
|
- # update all_indices |
|
- all_indices.extend([[other_row_idx, i] for i in range(other_row_B)]) |
|
- |
|
# concat |
|
all_ssim = torch.cat(all_ssim, dim=1) |
|
- assert all_ssim.shape == (row_B, total) |
|
+ assert all_ssim.shape == (xbdim, ydim) |
|
|
|
- all_indices = np.array(all_indices) |
|
- assert all_indices.shape == (total, 2) |
|
+ assert all_indices.shape == (ydim, 2) |
|
|
|
# get top-k indices |
|
topk_values, topk_indices = torch.topk(all_ssim, num_to_retrieve, dim=1, largest=True) |
|
topk_indices = topk_indices.cpu().numpy() |
|
- assert topk_indices.shape == (row_B, num_to_retrieve) |
|
+ assert topk_indices.shape == (xbdim, num_to_retrieve) |
|
|
|
# convert topk indices to indices in the dataset |
|
- retrieved_indices = np.array(all_indices[topk_indices]) |
|
- assert retrieved_indices.shape == (row_B, num_to_retrieve, 2) |
|
- |
|
- # pad the above to expected B |
|
- if row_B < B: |
|
- retrieved_indices = np.concatenate([retrieved_indices, np.zeros((B-row_B, num_to_retrieve, 2), dtype=int)], axis=0) |
|
- assert retrieved_indices.shape == (B, num_to_retrieve, 2) |
|
+ retrieved_indices = all_indices[topk_indices] |
|
+ assert retrieved_indices.shape == (xbdim, num_to_retrieve, 2) |
|
|
|
return retrieved_indices |
|
|
|
-def build_index_vector(all_rows_of_obs_og, |
|
- all_attn_masks_og, |
|
+def build_index_vector(all_rows_of_obs_OG, |
|
+ all_attn_masks_OG, |
|
all_rows_to_consider, |
|
kwargs |
|
- ): |
|
+ ): |
|
+ """ |
|
+ Builds FAISS index for vector observation environments. |
|
+ """ |
|
# read kwargs # Note: B = len of row |
|
- B, attn_key, obs_key, device, task, batch_size_retrieval, nb_cores_autofaiss = kwargs['B'], kwargs['attn_key'], kwargs['obs_key'], kwargs['device'], kwargs['task'], kwargs['batch_size_retrieval'], kwargs['nb_cores_autofaiss'] |
|
- obs_dim = get_obs_dim(task) |
|
+ nb_cores_autofaiss = kwargs['nb_cores_autofaiss'] |
|
|
|
- # take subset based on all_rows_to_consider |
|
- myprint(f'Taking subset') |
|
- all_rows_of_obs = all_rows_of_obs_og[all_rows_to_consider] |
|
- all_attn_masks = all_attn_masks_og[all_rows_to_consider] |
|
- assert (all_rows_of_obs.shape == (len(all_rows_to_consider), B, obs_dim) and |
|
- all_attn_masks.shape == (len(all_rows_to_consider), B)) |
|
- |
|
- # reshape |
|
- all_attn_masks = all_attn_masks.reshape(-1) |
|
- all_rows_of_obs = all_rows_of_obs.reshape(-1, obs_dim) |
|
- all_rows_of_obs = all_rows_of_obs[all_attn_masks] |
|
- assert all_rows_of_obs.shape == (np.sum(all_attn_masks), obs_dim) |
|
+ # take subset based on all_rows_to_consider, reshape, and save indices of data |
|
+ all_indices, all_processed_rows_of_obs = collect_subset(all_rows_of_obs_OG, all_attn_masks_OG, all_rows_to_consider, kwargs) |
|
|
|
- # save indices of data to retrieve from |
|
- myprint(f'Saving indices of data to retrieve from') |
|
- all_indices = np.array([[row_idx, i] for row_idx in all_rows_to_consider for i in range(B)]) |
|
- all_indices = all_indices[all_attn_masks] # this is fine because all attn masks have 0s that only come after 1s |
|
- assert all_indices.shape == (np.sum(all_attn_masks), 2) |
|
+ # make sure input to build_index is float, otherwise you will get reading temp file error |
|
+ all_processed_rows_of_obs = all_processed_rows_of_obs.astype(float) |
|
|
|
# build index |
|
- myprint(f'Building index...') |
|
- knn_index, knn_index_infos = build_index(embeddings=all_rows_of_obs, # Note: embeddings have to be float to avoid errors in autofaiss / embedding_reader! |
|
+ myprint(('-'*100) + 'Building index...') |
|
+ knn_index, knn_index_infos = build_index(embeddings=all_processed_rows_of_obs, # Note: embeddings have to be float to avoid errors in autofaiss / embedding_reader! |
|
save_on_disk=False, |
|
min_nearest_neighbors_to_retrieve=20, # default: 20 |
|
max_index_query_time_ms=10, # default: 10 |
|
@@ -179,34 +238,32 @@ def build_index_vector(all_rows_of_obs_og, |
|
metric_type='l2', |
|
nb_cores=nb_cores_autofaiss, # default: None # "The number of cores to use, by default will use all cores" as seen in https://criteo.github.io/autofaiss/getting_started/quantization.html#the-build-index-command |
|
) |
|
+ myprint(('-'*100) + '\n\n\n') |
|
|
|
- return knn_index, all_indices |
|
+ return all_indices, knn_index |
|
|
|
-def retrieve_vector(row_of_obs, # query: (row_B, dim) |
|
- dataset, # to retrieve from |
|
- all_rows_to_consider, # rows to consider |
|
- num_to_retrieve, # top-k |
|
+def retrieve_vector(row_of_obs, # query: (xbdim, *obs_dim) |
|
+ knn_index, |
|
+ all_indices, |
|
+ num_to_retrieve, |
|
kwargs |
|
- ): |
|
+ ): |
|
+ """ |
|
+ Retrieval for vector observation environments. |
|
+ """ |
|
assert isinstance(row_of_obs, np.ndarray) |
|
|
|
# read few kwargs |
|
B = kwargs['B'] |
|
|
|
# batch size of row_of_obs which can be <= B since we process before calling this function |
|
- row_B = row_of_obs.shape[0] |
|
+ xbdim = row_of_obs.shape[0] |
|
|
|
- # read dataset_tuple |
|
- all_rows_of_obs, all_attn_masks = dataset |
|
- |
|
- # create index and all_indices |
|
- knn_index, all_indices = build_index_vector(all_rows_of_obs, all_attn_masks, all_rows_to_consider, kwargs) |
|
- |
|
# retrieve |
|
myprint(f'Retrieving...') |
|
topk_indices, _ = knn_index.search(row_of_obs, 10 * num_to_retrieve) |
|
topk_indices = topk_indices.astype(int) |
|
- assert topk_indices.shape == (row_B, 10 * num_to_retrieve) |
|
+ assert topk_indices.shape == (xbdim, 10 * num_to_retrieve) |
|
|
|
# remove -1s and crop to num_to_retrieve |
|
try: |
|
@@ -219,16 +276,10 @@ def retrieve_vector(row_of_obs, # query: (row_B, dim) |
|
print(f'-------------------------------------------------------------------------------------------------------------------------------------------') |
|
print(f'Leaving some -1s in topk_indices and continuing') |
|
topk_indices = np.array([indices[:num_to_retrieve] for indices in topk_indices]) |
|
- assert topk_indices.shape == (row_B, num_to_retrieve) |
|
+ assert topk_indices.shape == (xbdim, num_to_retrieve) |
|
|
|
# convert topk indices to indices in the dataset |
|
retrieved_indices = all_indices[topk_indices] |
|
- assert retrieved_indices.shape == (row_B, num_to_retrieve, 2) |
|
- |
|
- # pad the above to expected B |
|
- if row_B < B: |
|
- retrieved_indices = np.concatenate([retrieved_indices, np.zeros((B-row_B, num_to_retrieve, 2), dtype=int)], axis=0) |
|
- assert retrieved_indices.shape == (B, num_to_retrieve, 2) |
|
+ assert retrieved_indices.shape == (xbdim, num_to_retrieve, 2) |
|
|
|
- myprint(f'Returning') |
|
return retrieved_indices |
|
\ No newline at end of file |
|
|
|
|
|
|
|
|
|
@@ -15,9 +15,10 @@ from transformers import AutoModelForCausalLM, AutoProcessor, HfArgumentParser |
|
|
|
from jat.eval.rl import TASK_NAME_TO_ENV_ID, make |
|
from jat.utils import normalize, push_to_hub, save_video_grid |
|
-from jat_regent.RandP import RandP |
|
+from jat_regent.modeling_RandP import RandP |
|
from datasets import load_from_disk |
|
from datasets.config import HF_DATASETS_CACHE |
|
+from jat_regent.utils import myprint |
|
|
|
|
|
@dataclass |
|
@@ -70,6 +71,7 @@ def eval_rl(model, processor, task, eval_args): |
|
scores = [] |
|
frames = [] |
|
for episode in tqdm(range(eval_args.num_episodes), desc=task, unit="episode", leave=False): |
|
+ myprint(('-'*100) + f'{episode=}') |
|
observation, _ = env.reset() |
|
reward = None |
|
rewards = [] |
|
@@ -96,6 +98,7 @@ def eval_rl(model, processor, task, eval_args): |
|
frames.append(np.array(env.render(), dtype=np.uint8)) |
|
|
|
scores.append(sum(rewards)) |
|
+ myprint(('-'*100) + '\n\n\n') |
|
env.close() |
|
|
|
raw_mean, raw_std = np.mean(scores), np.std(scores) |
|
@@ -145,7 +148,9 @@ def main(): |
|
tasks.extend([env_id for env_id in TASK_NAME_TO_ENV_ID.keys() if env_id.startswith(domain)]) |
|
|
|
device = torch.device("cpu") if eval_args.use_cpu else get_default_device() |
|
- processor = None |
|
+ processor = AutoProcessor.from_pretrained( |
|
+ 'jat-project/jat', cache_dir=None, trust_remote_code=True |
|
+ ) |
|
|
|
evaluations = {} |
|
video_list = [] |
|
@@ -153,14 +158,18 @@ def main(): |
|
|
|
for task in tqdm(tasks, desc="Evaluation", unit="task", leave=True): |
|
if task in TASK_NAME_TO_ENV_ID.keys(): |
|
+ myprint(('-'*100) + f'{task=}') |
|
dataset = load_from_disk(f'{HF_DATASETS_CACHE}/jat-project/jat-dataset-tokenized/{task}') |
|
- model = RandP(dataset) |
|
+ model = RandP(task, |
|
+ dataset, |
|
+ device,) |
|
scores, frames, fps = eval_rl(model, processor, task, eval_args) |
|
evaluations[task] = scores |
|
# Save the video |
|
if eval_args.save_video: |
|
video_list.append(frames) |
|
input_fps.append(fps) |
|
+ myprint(('-'*100) + '\n\n\n') |
|
else: |
|
warnings.warn(f"Task {task} is not supported.") |
|
|
|
|
|
|
|
|
|
|
|
@@ -8,7 +8,7 @@ import time |
|
from datetime import datetime |
|
from datasets import load_from_disk |
|
from datasets.config import HF_DATASETS_CACHE |
|
-from jat_regent.utils import myprint, process_row_atari, process_row_vector, retrieve_atari, retrieve_vector |
|
+from jat_regent.utils import myprint, get_task_info, collect_all_data, process_row_of_obs_atari_full_without_mask, retrieve_atari, retrieve_vector, collect_subset, build_index_vector |
|
import logging |
|
logging.basicConfig(level=logging.DEBUG) |
|
|
|
@@ -17,7 +17,8 @@ def main(): |
|
parser = argparse.ArgumentParser(description='Build RAAGENT sequence indices') |
|
parser.add_argument('--task', type=str, default='atari-alien', help='Task name') |
|
parser.add_argument('--num_to_retrieve', type=int, default=100, help='Number of states/windows to retrieve') |
|
- parser.add_argument('--nb_cores_autofaiss', type=int, default=16, help='Number of cores to use for faiss in vector observation environments') |
|
+ parser.add_argument('--nb_cores_autofaiss', type=int, default=16, help='Number of cores to use for faiss in vector obs envs') |
|
+ parser.add_argument('--batch_size_retrieval', type=int, default=1024, help='Batch size for retrieval in atari') |
|
args = parser.parse_args() |
|
|
|
# load dataset, map, device, for task |
|
@@ -25,77 +26,83 @@ def main(): |
|
dataset_path = f"{HF_DATASETS_CACHE}/jat-project/jat-dataset-tokenized/{task}" |
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
|
- rew_key = 'rewards' |
|
- attn_key = 'attention_mask' |
|
- if task.startswith("atari"): |
|
- obs_key = 'image_observations' |
|
- act_key = 'discrete_actions' |
|
- len_row_tokenized_known = 32 # half of 54 |
|
- process_row_fn = process_row_atari |
|
- retrieve_fn = retrieve_atari |
|
- elif task.startswith("babyai"): |
|
- obs_key = 'discrete_observations'# also has 'text_observations' only for raw dataset not for tokenized dataset (as it is combined into discrete_observation in tokenized dataset) |
|
- act_key = 'discrete_actions' |
|
- len_row_tokenized_known = 256 # half of 512 |
|
- process_row_fn = lambda attn_mask, row_of_obs, task: process_row_vector(attn_mask, row_of_obs, task, return_numpy=True) |
|
- retrieve_fn = retrieve_vector |
|
- elif task.startswith("metaworld") or task.startswith("mujoco"): |
|
- obs_key = 'continuous_observations' |
|
- act_key = 'continuous_actions' |
|
- len_row_tokenized_known = 256 |
|
- process_row_fn = lambda attn_mask, row_of_obs, task: process_row_vector(attn_mask, row_of_obs, task, return_numpy=True) |
|
- retrieve_fn = retrieve_vector |
|
+ rew_key, attn_key, obs_key, act_key, B, obs_dim = get_task_info(task) |
|
|
|
dataset = load_from_disk(dataset_path) |
|
with open(f"{dataset_path}/map_from_rows_to_episodes_for_tokenized.json", 'r') as f: |
|
map_from_rows_to_episodes_for_tokenized = json.load(f) |
|
|
|
# setup kwargs |
|
- len_dataset = len(dataset['train']) |
|
- B = len_row_tokenized_known |
|
kwargs = {'B': B, |
|
- 'attn_key':attn_key, |
|
- 'obs_key':obs_key, |
|
- 'device':device, |
|
- 'task':task, |
|
- 'batch_size_retrieval':None, |
|
- 'nb_cores_autofaiss':None if task.startswith("atari") else args.nb_cores_autofaiss, |
|
- } |
|
+ 'obs_dim': obs_dim, |
|
+ 'attn_key': attn_key, |
|
+ 'obs_key': obs_key, |
|
+ 'device': device, |
|
+ 'task': task, |
|
+ 'batch_size_retrieval': args.batch_size_retrieval, |
|
+ 'nb_cores_autofaiss': None if task.startswith("atari") else args.nb_cores_autofaiss, |
|
+ } |
|
|
|
# collect all observations in a single array (this takes some time) for vector observation environments |
|
- if not task.startswith("atari"): |
|
- myprint("Collecting all observations/attn_masks in a single array") |
|
- all_rows_of_obs = np.array(dataset['train'][obs_key]) |
|
- all_attn_masks = np.array(dataset['train'][attn_key]).astype(bool) |
|
+ myprint("Collecting all observations/attn_masks") |
|
+ all_rows_of_obs_OG, all_attn_masks_OG, all_row_idxs = collect_all_data(dataset, task, obs_key) |
|
|
|
# iterate over rows |
|
all_retrieved_indices = [] |
|
- for row_idx in range(len_dataset): |
|
- myprint(f"\nProcessing row {row_idx}/{len_dataset}") |
|
+ for row_idx in all_row_idxs: |
|
+ myprint(f"\nProcessing row {row_idx}/{len(all_row_idxs)}") |
|
current_ep = map_from_rows_to_episodes_for_tokenized[str(row_idx)] |
|
|
|
- attn_mask, row_of_obs = process_row_fn(dataset['train'][row_idx][attn_key], dataset['train'][row_idx][obs_key], task) |
|
+ # get row_of_obs and attn_mask |
|
+ datarow = dataset['train'][row_idx] |
|
+ attn_mask = np.array(datarow[attn_key]).astype(bool) |
|
+ if task.startswith("atari"): |
|
+ row_of_obs = process_row_of_obs_atari_full_without_mask(datarow[obs_key]) |
|
+ else: |
|
+ row_of_obs = np.array(datarow[obs_key]) |
|
+ row_of_obs = row_of_obs[attn_mask] |
|
+ assert row_of_obs.shape == (np.sum(attn_mask), *obs_dim) |
|
|
|
# compare with rows from all but the current episode |
|
- all_other_rows = [idx for idx in range(len_dataset) if map_from_rows_to_episodes_for_tokenized[str(idx)] != current_ep] |
|
+ all_other_row_idxs = [idx for idx in all_row_idxs if map_from_rows_to_episodes_for_tokenized[str(idx)] != current_ep] |
|
|
|
# do the retrieval |
|
- retrieved_indices = retrieve_fn(row_of_obs=row_of_obs, |
|
- dataset=dataset if task.startswith("atari") else (all_rows_of_obs, all_attn_masks), |
|
- all_rows_to_consider=all_other_rows, |
|
- num_to_retrieve=args.num_to_retrieve, |
|
- kwargs=kwargs, |
|
- ) |
|
+ if task.startswith("atari"): |
|
+ all_indices, all_processed_rows_of_obs = collect_subset(all_rows_of_obs_OG=all_rows_of_obs_OG, |
|
+ all_attn_masks_OG=all_attn_masks_OG, |
|
+ all_rows_to_consider=all_row_idxs, |
|
+ kwargs=kwargs) |
|
+ retrieved_indices = retrieve_atari(row_of_obs=row_of_obs, |
|
+ all_processed_rows_of_obs=all_processed_rows_of_obs, |
|
+ all_indices=all_indices, |
|
+ num_to_retrieve=args.num_to_retrieve, |
|
+ kwargs=kwargs) |
|
+ else: |
|
+ all_indices, knn_index = build_index_vector(all_rows_of_obs_OG=all_rows_of_obs_OG, |
|
+ all_attn_masks_OG=all_attn_masks_OG, |
|
+ all_rows_to_consider=all_other_row_idxs, |
|
+ kwargs=kwargs) |
|
+ retrieved_indices = retrieve_vector(row_of_obs=row_of_obs, |
|
+ knn_index=knn_index, |
|
+ all_indices=all_indices, |
|
+ num_to_retrieve=args.num_to_retrieve, |
|
+ kwargs=kwargs) |
|
+ |
|
+ # pad the above to expected B |
|
+ xbdim = row_of_obs.shape[0] |
|
+ if xbdim < B: |
|
+ retrieved_indices = np.concatenate([retrieved_indices, np.zeros((B-xbdim, args.num_to_retrieve, 2), dtype=int)], axis=0) |
|
+ assert retrieved_indices.shape == (B, args.num_to_retrieve, 2) |
|
|
|
# collect retrieved indices |
|
all_retrieved_indices.append(retrieved_indices) |
|
|
|
# concat |
|
all_retrieved_indices = np.stack(all_retrieved_indices, axis=0) |
|
- assert all_retrieved_indices.shape == (len_dataset, B, args.num_to_retrieve, 2) |
|
+ assert all_retrieved_indices.shape == (len(all_row_idxs), B, args.num_to_retrieve, 2) |
|
|
|
# save arrays as bin for easy memmap access and faster loading |
|
- all_retrieved_indices.tofile(f"{dataset_path}/retrieved_indices_{len_dataset}_{B}_{args.num_to_retrieve}_2.bin") |
|
+ all_retrieved_indices.tofile(f"{dataset_path}/retrieved_indices_{len(all_row_idxs)}_{B}_{args.num_to_retrieve}_2.bin") |
|
|
|
if __name__ == "__main__": |
|
main() |
|
\ No newline at end of file |
|
|