ksridhar's picture
Upload folder using huggingface_hub
d007384 verified
raw
history blame
44.7 kB
diff --git a/README.md b/README.md
index e51a12b..a6e1ca1 100644
--- a/README.md
+++ b/README.md
@@ -21,6 +21,21 @@ conda activate jat
pip install -e .[dev]
```
+## REGENT fork of sample-factory: Installation
+Following [this install ink](https://www.samplefactory.dev/01-get-started/installation/) but for the fork:
+```shell
+git clone https://github.com/kaustubhsridhar/sample-factory.git
+cd sample-factory
+pip install -e .[dev,mujoco,atari,envpool,vizdoom]
+```
+
+# Regent fork of sample-factory: Train Unseen Env Policies and Generate Datasets
+Train policies using envpool's atari:
+```shell
+bash scripts_sample-factory/train_unseen_atari.sh
+```
+Note that the training command inside the above script was obtained from the config files of Ed Beeching's Atari 57 models on Huggingface. An example is [here](https://huggingface.co/edbeeching/atari_2B_atari_mspacman_1111/blob/main/cfg.json#L124). See my discussion [here](https://huggingface.co/edbeeching/atari_2B_atari_mspacman_1111/discussions/2).
+
## PREV Installation
To get started with JAT, follow these steps:
@@ -155,12 +170,21 @@ python -u scripts_jat_regent/eval_RandP.py --task ${TASK} &> outputs/RandP/${TAS
```
### REGENT Analyze data
+Necessary:
```shell
-python -u examples_regent/compare_datasets.py &> examples_regent/compare_datasets.txt &
-
python -u examples_regent/analyze_rows_tokenized.py &> examples_regent/analyze_rows_tokenized.txt &
+```
+Already ran and output dict in code:
+```shell
python -u examples_regent/get_dim_all_vector_tasks.py &> examples_regent/get_dim_all_vector_tasks.txt &
+
+python -u examples_regent/count_rows_to_consider.py &> examples_regent/count_rows_to_consider.txt &
+```
+
+Optional:
+```shell
+python -u examples_regent/compare_datasets.py &> examples_regent/compare_datasets.txt &
```
## PREV Dataset
diff --git a/jat_regent/RandP.py b/jat_regent/RandP.py
deleted file mode 100644
index b2bd8bf..0000000
--- a/jat_regent/RandP.py
+++ /dev/null
@@ -1,38 +0,0 @@
-import warnings
-from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
-
-import numpy as np
-import torch
-import torch.nn.functional as F
-from gymnasium import spaces
-from torch import BoolTensor, FloatTensor, LongTensor, Tensor, nn
-from transformers import GPTNeoModel, GPTNeoPreTrainedModel
-from transformers.modeling_outputs import ModelOutput
-from transformers.models.vit.modeling_vit import ViTPatchEmbeddings
-
-from jat.configuration_jat import JatConfig
-from jat.processing_jat import JatProcessor
-
-
-class RandP():
- def __init__(self, dataset) -> None:
- self.steps = 0
- # create an index for retrieval in vector obs envs (OR) collect all images in Atari
-
- def reset_rl(self):
- self.steps = 0
-
- def get_next_action(
- self,
- processor: JatProcessor,
- continuous_observation: Optional[List[float]] = None,
- discrete_observation: Optional[List[int]] = None,
- text_observation: Optional[str] = None,
- image_observation: Optional[np.ndarray] = None,
- action_space: Union[spaces.Box, spaces.Discrete] = None,
- reward: Optional[float] = None,
- deterministic: bool = False,
- context_window: Optional[int] = None,
- ):
- pass
\ No newline at end of file
diff --git a/jat_regent/modelling_jat_regent.py b/jat_regent/modelling_jat_regent.py
deleted file mode 100644
index e69de29..0000000
diff --git a/jat_regent/utils.py b/jat_regent/utils.py
index 56bfb44..36f6cca 100644
--- a/jat_regent/utils.py
+++ b/jat_regent/utils.py
@@ -8,23 +8,35 @@ from tqdm import tqdm
from autofaiss import build_index
+UNSEEN_TASK_NAMES = { # Total -- atari: 57, metaworld: 50, babyai: 39, mujoco: 11
+
+}
+
def myprint(str):
- # check if first character of string is a newline character
- if str[0] == '\n':
- str_without_newline = str[1:]
+ # check if first characters of string are newline character
+ num_newlines = 0
+ while str[num_newlines] == '\n':
print()
- else:
- str_without_newline = str
+ num_newlines += 1
+ str_without_newline = str[num_newlines:]
print(f'{datetime.now().strftime("%Y-%m-%d %H:%M:%S")}: {str_without_newline}')
def is_png_img(item):
return isinstance(item, PngImagePlugin.PngImageFile)
+def get_last_row_for_1M_states(task):
+ last_row_idx = {'atari-alien': 14134, 'atari-amidar': 14319, 'atari-assault': 14427, 'atari-asterix': 14456, 'atari-asteroids': 14348, 'atari-atlantis': 14325, 'atari-bankheist': 14167, 'atari-battlezone': 13981, 'atari-beamrider': 13442, 'atari-berzerk': 13534, 'atari-bowling': 14110, 'atari-boxing': 14542, 'atari-breakout': 13474, 'atari-centipede': 14196, 'atari-choppercommand': 13397, 'atari-crazyclimber': 14026, 'atari-defender': 13504, 'atari-demonattack': 13499, 'atari-doubledunk': 14292, 'atari-enduro': 13260, 'atari-fishingderby': 14073, 'atari-freeway': 14016, 'atari-frostbite': 14075, 'atari-gopher': 13143, 'atari-gravitar': 14405, 'atari-hero': 14044, 'atari-icehockey': 14017, 'atari-jamesbond': 12678, 'atari-kangaroo': 14248, 'atari-krull': 14204, 'atari-kungfumaster': 14030, 'atari-montezumarevenge': 14219, 'atari-mspacman': 14120, 'atari-namethisgame': 13575, 'atari-phoenix': 13539, 'atari-pitfall': 14287, 'atari-pong': 14151, 'atari-privateeye': 14105, 'atari-qbert': 14026, 'atari-riverraid': 14275, 'atari-roadrunner': 14127, 'atari-robotank': 14079, 'atari-seaquest': 14097, 'atari-skiing': 14708, 'atari-solaris': 14199, 'atari-spaceinvaders': 12652, 'atari-stargunner': 13822, 'atari-surround': 13840, 'atari-tennis': 14062, 'atari-timepilot': 13896, 'atari-tutankham': 13121, 'atari-upndown': 13504, 'atari-venture': 14260, 'atari-videopinball': 14272, 'atari-wizardofwor': 13920, 'atari-yarsrevenge': 13981, 'atari-zaxxon': 13833, 'babyai-action-obj-door': 95000, 'babyai-blocked-unlock-pickup': 29279, 'babyai-boss-level-no-unlock': 12087, 'babyai-boss-level': 12101, 'babyai-find-obj-s5': 32974, 'babyai-go-to-door': 95000, 'babyai-go-to-imp-unlock': 9286, 'babyai-go-to-local': 95000, 'babyai-go-to-obj-door': 95000, 'babyai-go-to-obj': 95000, 'babyai-go-to-red-ball-grey': 95000, 'babyai-go-to-red-ball-no-dists': 95000, 'babyai-go-to-red-ball': 95000, 'babyai-go-to-red-blue-ball': 95000, 'babyai-go-to-seq': 13744, 'babyai-go-to': 18974, 'babyai-key-corridor': 9014, 'babyai-mini-boss-level': 38119, 'babyai-move-two-across-s8n9': 24505, 'babyai-one-room-s8': 95000, 'babyai-open-door': 95000, 'babyai-open-doors-order-n4': 95000, 'babyai-open-red-door': 95000, 'babyai-open-two-doors': 73291, 'babyai-open': 32559, 'babyai-pickup-above': 34084, 'babyai-pickup-dist': 89640, 'babyai-pickup-loc': 95000, 'babyai-pickup': 18670, 'babyai-put-next-local': 83187, 'babyai-put-next': 56986, 'babyai-synth-loc': 21605, 'babyai-synth-seq': 13049, 'babyai-synth': 19409, 'babyai-unblock-pickup': 17881, 'babyai-unlock-local': 71186, 'babyai-unlock-pickup': 50883, 'babyai-unlock-to-unlock': 23062, 'babyai-unlock': 11734, 'metaworld-assembly': 10000, 'metaworld-basketball': 10000, 'metaworld-bin-picking': 10000, 'metaworld-box-close': 10000, 'metaworld-button-press-topdown-wall': 10000, 'metaworld-button-press-topdown': 10000, 'metaworld-button-press-wall': 10000, 'metaworld-button-press': 10000, 'metaworld-coffee-button': 10000, 'metaworld-coffee-pull': 10000, 'metaworld-coffee-push': 10000, 'metaworld-dial-turn': 10000, 'metaworld-disassemble': 10000, 'metaworld-door-close': 10000, 'metaworld-door-lock': 10000, 'metaworld-door-open': 10000, 'metaworld-door-unlock': 10000, 'metaworld-drawer-close': 10000, 'metaworld-drawer-open': 10000, 'metaworld-faucet-close': 10000, 'metaworld-faucet-open': 10000, 'metaworld-hammer': 10000, 'metaworld-hand-insert': 10000, 'metaworld-handle-press-side': 10000, 'metaworld-handle-press': 10000, 'metaworld-handle-pull-side': 10000, 'metaworld-handle-pull': 10000, 'metaworld-lever-pull': 10000, 'metaworld-peg-insert-side': 10000, 'metaworld-peg-unplug-side': 10000, 'metaworld-pick-out-of-hole': 10000, 'metaworld-pick-place-wall': 10000, 'metaworld-pick-place': 10000, 'metaworld-plate-slide-back-side': 10000, 'metaworld-plate-slide-back': 10000, 'metaworld-plate-slide-side': 10000, 'metaworld-plate-slide': 10000, 'metaworld-push-back': 10000, 'metaworld-push-wall': 10000, 'metaworld-push': 10000, 'metaworld-reach-wall': 10000, 'metaworld-reach': 10000, 'metaworld-shelf-place': 10000, 'metaworld-soccer': 10000, 'metaworld-stick-pull': 10000, 'metaworld-stick-push': 10000, 'metaworld-sweep-into': 10000, 'metaworld-sweep': 10000, 'metaworld-window-close': 10000, 'metaworld-window-open': 10000, 'mujoco-ant': 4023, 'mujoco-doublependulum': 4002, 'mujoco-halfcheetah': 4000, 'mujoco-hopper': 4931, 'mujoco-humanoid': 4119, 'mujoco-pendulum': 4959, 'mujoco-pusher': 9000, 'mujoco-reacher': 9000, 'mujoco-standup': 4000, 'mujoco-swimmer': 4000, 'mujoco-walker': 4101}
+ return last_row_idx[task]
+
+def get_last_row_for_100k_states(task):
+ last_row_idx = {'atari-alien': 3135, 'atari-amidar': 3142, 'atari-assault': 3132, 'atari-asterix': 3181, 'atari-asteroids': 3127, 'atari-atlantis': 3128, 'atari-bankheist': 3156, 'atari-battlezone': 3136, 'atari-beamrider': 3131, 'atari-berzerk': 3127, 'atari-bowling': 3148, 'atari-boxing': 3227, 'atari-breakout': 3128, 'atari-centipede': 3176, 'atari-choppercommand': 3144, 'atari-crazyclimber': 3134, 'atari-defender': 3127, 'atari-demonattack': 3127, 'atari-doubledunk': 3175, 'atari-enduro': 3126, 'atari-fishingderby': 3155, 'atari-freeway': 3131, 'atari-frostbite': 3146, 'atari-gopher': 3128, 'atari-gravitar': 3202, 'atari-hero': 3144, 'atari-icehockey': 3138, 'atari-jamesbond': 3131, 'atari-kangaroo': 3160, 'atari-krull': 3162, 'atari-kungfumaster': 3143, 'atari-montezumarevenge': 3168, 'atari-mspacman': 3143, 'atari-namethisgame': 3131, 'atari-phoenix': 3127, 'atari-pitfall': 3131, 'atari-pong': 3160, 'atari-privateeye': 3158, 'atari-qbert': 3136, 'atari-riverraid': 3157, 'atari-roadrunner': 3150, 'atari-robotank': 3133, 'atari-seaquest': 3138, 'atari-skiing': 3271, 'atari-solaris': 3129, 'atari-spaceinvaders': 3128, 'atari-stargunner': 3129, 'atari-surround': 3143, 'atari-tennis': 3129, 'atari-timepilot': 3132, 'atari-tutankham': 3127, 'atari-upndown': 3127, 'atari-venture': 3148, 'atari-videopinball': 3130, 'atari-wizardofwor': 3138, 'atari-yarsrevenge': 3129, 'atari-zaxxon': 3133, 'babyai-action-obj-door': 15923, 'babyai-blocked-unlock-pickup': 2919, 'babyai-boss-level-no-unlock': 1217, 'babyai-boss-level': 1159, 'babyai-find-obj-s5': 3345, 'babyai-go-to-door': 18875, 'babyai-go-to-imp-unlock': 923, 'babyai-go-to-local': 18724, 'babyai-go-to-obj-door': 16472, 'babyai-go-to-obj': 20197, 'babyai-go-to-red-ball-grey': 16953, 'babyai-go-to-red-ball-no-dists': 20165, 'babyai-go-to-red-ball': 18730, 'babyai-go-to-red-blue-ball': 16934, 'babyai-go-to-seq': 1439, 'babyai-go-to': 1964, 'babyai-key-corridor': 900, 'babyai-mini-boss-level': 3789, 'babyai-move-two-across-s8n9': 2462, 'babyai-one-room-s8': 16994, 'babyai-open-door': 13565, 'babyai-open-doors-order-n4': 9706, 'babyai-open-red-door': 21185, 'babyai-open-two-doors': 7348, 'babyai-open': 3331, 'babyai-pickup-above': 3392, 'babyai-pickup-dist': 19693, 'babyai-pickup-loc': 16405, 'babyai-pickup': 1806, 'babyai-put-next-local': 8303, 'babyai-put-next': 5703, 'babyai-synth-loc': 2183, 'babyai-synth-seq': 1316, 'babyai-synth': 1964, 'babyai-unblock-pickup': 1886, 'babyai-unlock-local': 7118, 'babyai-unlock-pickup': 5107, 'babyai-unlock-to-unlock': 2309, 'babyai-unlock': 1177, 'metaworld-assembly': 1000, 'metaworld-basketball': 1000, 'metaworld-bin-picking': 1000, 'metaworld-box-close': 1000, 'metaworld-button-press-topdown-wall': 1000, 'metaworld-button-press-topdown': 1000, 'metaworld-button-press-wall': 1000, 'metaworld-button-press': 1000, 'metaworld-coffee-button': 1000, 'metaworld-coffee-pull': 1000, 'metaworld-coffee-push': 1000, 'metaworld-dial-turn': 1000, 'metaworld-disassemble': 1000, 'metaworld-door-close': 1000, 'metaworld-door-lock': 1000, 'metaworld-door-open': 1000, 'metaworld-door-unlock': 1000, 'metaworld-drawer-close': 1000, 'metaworld-drawer-open': 1000, 'metaworld-faucet-close': 1000, 'metaworld-faucet-open': 1000, 'metaworld-hammer': 1000, 'metaworld-hand-insert': 1000, 'metaworld-handle-press-side': 1000, 'metaworld-handle-press': 1000, 'metaworld-handle-pull-side': 1000, 'metaworld-handle-pull': 1000, 'metaworld-lever-pull': 1000, 'metaworld-peg-insert-side': 1000, 'metaworld-peg-unplug-side': 1000, 'metaworld-pick-out-of-hole': 1000, 'metaworld-pick-place-wall': 1000, 'metaworld-pick-place': 1000, 'metaworld-plate-slide-back-side': 1000, 'metaworld-plate-slide-back': 1000, 'metaworld-plate-slide-side': 1000, 'metaworld-plate-slide': 1000, 'metaworld-push-back': 1000, 'metaworld-push-wall': 1000, 'metaworld-push': 1000, 'metaworld-reach-wall': 1000, 'metaworld-reach': 1000, 'metaworld-shelf-place': 1000, 'metaworld-soccer': 1000, 'metaworld-stick-pull': 1000, 'metaworld-stick-push': 1000, 'metaworld-sweep-into': 1000, 'metaworld-sweep': 1000, 'metaworld-window-close': 1000, 'metaworld-window-open': 1000, 'mujoco-ant': 401, 'mujoco-doublependulum': 401, 'mujoco-halfcheetah': 400, 'mujoco-hopper': 491, 'mujoco-humanoid': 415, 'mujoco-pendulum': 495, 'mujoco-pusher': 1000, 'mujoco-reacher': 2000, 'mujoco-standup': 400, 'mujoco-swimmer': 400, 'mujoco-walker': 407}
+ return last_row_idx[task]
+
def get_obs_dim(task):
assert task.startswith("babyai") or task.startswith("metaworld") or task.startswith("mujoco")
all_obs_dims={'babyai-action-obj-door': 212, 'babyai-blocked-unlock-pickup': 212, 'babyai-boss-level-no-unlock': 212, 'babyai-boss-level': 212, 'babyai-find-obj-s5': 212, 'babyai-go-to-door': 212, 'babyai-go-to-imp-unlock': 212, 'babyai-go-to-local': 212, 'babyai-go-to-obj-door': 212, 'babyai-go-to-obj': 212, 'babyai-go-to-red-ball-grey': 212, 'babyai-go-to-red-ball-no-dists': 212, 'babyai-go-to-red-ball': 212, 'babyai-go-to-red-blue-ball': 212, 'babyai-go-to-seq': 212, 'babyai-go-to': 212, 'babyai-key-corridor': 212, 'babyai-mini-boss-level': 212, 'babyai-move-two-across-s8n9': 212, 'babyai-one-room-s8': 212, 'babyai-open-door': 212, 'babyai-open-doors-order-n4': 212, 'babyai-open-red-door': 212, 'babyai-open-two-doors': 212, 'babyai-open': 212, 'babyai-pickup-above': 212, 'babyai-pickup-dist': 212, 'babyai-pickup-loc': 212, 'babyai-pickup': 212, 'babyai-put-next-local': 212, 'babyai-put-next': 212, 'babyai-synth-loc': 212, 'babyai-synth-seq': 212, 'babyai-synth': 212, 'babyai-unblock-pickup': 212, 'babyai-unlock-local': 212, 'babyai-unlock-pickup': 212, 'babyai-unlock-to-unlock': 212, 'babyai-unlock': 212, 'metaworld-assembly': 39, 'metaworld-basketball': 39, 'metaworld-bin-picking': 39, 'metaworld-box-close': 39, 'metaworld-button-press-topdown-wall': 39, 'metaworld-button-press-topdown': 39, 'metaworld-button-press-wall': 39, 'metaworld-button-press': 39, 'metaworld-coffee-button': 39, 'metaworld-coffee-pull': 39, 'metaworld-coffee-push': 39, 'metaworld-dial-turn': 39, 'metaworld-disassemble': 39, 'metaworld-door-close': 39, 'metaworld-door-lock': 39, 'metaworld-door-open': 39, 'metaworld-door-unlock': 39, 'metaworld-drawer-close': 39, 'metaworld-drawer-open': 39, 'metaworld-faucet-close': 39, 'metaworld-faucet-open': 39, 'metaworld-hammer': 39, 'metaworld-hand-insert': 39, 'metaworld-handle-press-side': 39, 'metaworld-handle-press': 39, 'metaworld-handle-pull-side': 39, 'metaworld-handle-pull': 39, 'metaworld-lever-pull': 39, 'metaworld-peg-insert-side': 39, 'metaworld-peg-unplug-side': 39, 'metaworld-pick-out-of-hole': 39, 'metaworld-pick-place-wall': 39, 'metaworld-pick-place': 39, 'metaworld-plate-slide-back-side': 39, 'metaworld-plate-slide-back': 39, 'metaworld-plate-slide-side': 39, 'metaworld-plate-slide': 39, 'metaworld-push-back': 39, 'metaworld-push-wall': 39, 'metaworld-push': 39, 'metaworld-reach-wall': 39, 'metaworld-reach': 39, 'metaworld-shelf-place': 39, 'metaworld-soccer': 39, 'metaworld-stick-pull': 39, 'metaworld-stick-push': 39, 'metaworld-sweep-into': 39, 'metaworld-sweep': 39, 'metaworld-window-close': 39, 'metaworld-window-open': 39, 'mujoco-ant': 27, 'mujoco-doublependulum': 11, 'mujoco-halfcheetah': 17, 'mujoco-hopper': 11, 'mujoco-humanoid': 376, 'mujoco-pendulum': 4, 'mujoco-pusher': 23, 'mujoco-reacher': 11, 'mujoco-standup': 376, 'mujoco-swimmer': 8, 'mujoco-walker': 17}
- return all_obs_dims[task]
+ return (all_obs_dims[task],)
def get_act_dim(task):
assert task.startswith("babyai") or task.startswith("metaworld") or task.startswith("mujoco")
@@ -36,141 +48,188 @@ def get_act_dim(task):
elif task.startswith("mujoco"):
all_act_dims={'mujoco-ant': 8, 'mujoco-doublependulum': 1, 'mujoco-halfcheetah': 6, 'mujoco-hopper': 3, 'mujoco-humanoid': 17, 'mujoco-pendulum': 1, 'mujoco-pusher': 7, 'mujoco-reacher': 2, 'mujoco-standup': 17, 'mujoco-swimmer': 2, 'mujoco-walker': 6}
return all_act_dims[task]
-
-def process_row_atari(attn_mask, row_of_obs, task):
- """
- Example for selection with bools:
- >>> a = np.array([0,1,2,3,4,5])
- >>> b = np.array([1,0,0,0,0,1]).astype(bool)
- >>> a[b]
- array([0, 5])
- """
- attn_mask = np.array(attn_mask).astype(bool)
- row_of_obs = torch.stack([to_tensor(np.array(img)) for img in row_of_obs])
- row_of_obs = row_of_obs[attn_mask]
+def get_task_info(task):
+ rew_key = 'rewards'
+ attn_key = 'attention_mask'
+ if task.startswith("atari"):
+ obs_key = 'image_observations'
+ act_key = 'discrete_actions'
+ B = 32 # half of 54
+ obs_dim = (3, 4*84, 84)
+ elif task.startswith("babyai"):
+ obs_key = 'discrete_observations' # also has 'text_observations' only for raw dataset not for tokenized dataset (as it is combined into discrete_observation in tokenized dataset)
+ act_key = 'discrete_actions'
+ B = 256 # half of 512
+ obs_dim = get_obs_dim(task)
+ elif task.startswith("metaworld") or task.startswith("mujoco"):
+ obs_key = 'continuous_observations'
+ act_key = 'continuous_actions'
+ B = 256
+ obs_dim = get_obs_dim(task)
+
+ return rew_key, attn_key, obs_key, act_key, B, obs_dim
+
+def process_row_of_obs_atari_full_without_mask(row_of_obs):
+
+ if not isinstance(row_of_obs, torch.Tensor):
+ row_of_obs = torch.stack([to_tensor(np.array(img)) for img in row_of_obs])
row_of_obs = row_of_obs * 0.5 + 0.5 # denormalize from [-1, 1] to [0, 1]
- assert row_of_obs.shape == (sum(attn_mask), 84, 4, 84)
+ assert row_of_obs.shape == (len(row_of_obs), 84, 4, 84)
row_of_obs = row_of_obs.permute(0, 2, 1, 3) # (*, 4, 84, 84)
- row_of_obs = row_of_obs.reshape(sum(attn_mask), 4*84, 84) # put side-by-side
+ row_of_obs = row_of_obs.reshape(len(row_of_obs), 4*84, 84) # put side-by-side
row_of_obs = row_of_obs.unsqueeze(1).repeat(1, 3, 1, 1) # repeat for 3 channels
- assert row_of_obs.shape == (sum(attn_mask), 3, 4*84, 84) # sum(attn_mask) is the batch size dimension
-
- return attn_mask, row_of_obs
+ assert row_of_obs.shape == (len(row_of_obs), 3, 4*84, 84) # sum(attn_mask) is the batch size dimension
+
+ return row_of_obs
-def process_row_vector(attn_mask, row_of_obs, task, return_numpy=False):
- attn_mask = np.array(attn_mask).astype(bool)
+def collect_all_atari_data(dataset, all_row_idxs=None):
+ if all_row_idxs is None:
+ all_row_idxs = list(range(len(dataset['train'])))
- row_of_obs = np.array(row_of_obs)
- if not return_numpy:
- row_of_obs = torch.tensor(row_of_obs)
- row_of_obs = row_of_obs[attn_mask]
- assert row_of_obs.shape == (sum(attn_mask), get_obs_dim(task))
-
- return attn_mask, row_of_obs
-
-def retrieve_atari(row_of_obs, # query: (row_B, 3, 4*84, 84)
- dataset, # to retrieve from
- all_rows_to_consider, # rows to consider
- num_to_retrieve, # top-k
+ all_rows_of_obs = []
+ all_attn_masks = []
+ for row_idx in tqdm(all_row_idxs):
+ datarow = dataset['train'][row_idx]
+ row_of_obs = process_row_of_obs_atari_full_without_mask(datarow['image_observations'])
+ attn_mask = np.array(datarow['attention_mask']).astype(bool)
+ all_rows_of_obs.append(row_of_obs) # appending tensor
+ all_attn_masks.append(attn_mask) # appending np array
+ all_rows_of_obs = torch.stack(all_rows_of_obs, dim=0) # stacking tensors
+ all_attn_masks = np.stack(all_attn_masks, axis=0) # concatenating np arrays
+ assert (all_rows_of_obs.shape == (len(all_row_idxs), 32, 3, 4*84, 84) and
+ all_attn_masks.shape == (len(all_row_idxs), 32))
+ return all_attn_masks, all_rows_of_obs
+
+def collect_all_data(dataset, task, obs_key):
+ last_row_idx = get_last_row_for_100k_states(task)
+ all_row_idxs = list(range(last_row_idx))
+ if task.startswith("atari"):
+ myprint("Collecting all Atari images and Atari attention masks...")
+ all_attn_masks_OG, all_rows_of_obs_OG = collect_all_atari_data(dataset, all_row_idxs)
+ else:
+ datarows = dataset['train'][all_row_idxs]
+ all_rows_of_obs_OG = np.array(datarows[obs_key])
+ all_attn_masks_OG = np.array(datarows['attention_mask']).astype(bool)
+ return all_rows_of_obs_OG, all_attn_masks_OG, all_row_idxs
+
+def collect_subset(all_rows_of_obs_OG,
+ all_attn_masks_OG,
+ all_rows_to_consider,
+ kwargs
+ ):
+ """
+ Function to collect subset of data given all_rows_to_consider, reshape it, create all_indices and return.
+ Used in both retrieve_atari() and retrieve_vector() --> build_index_vector().
+ """
+ myprint(f'\n\n\n' + ('-'*100) + f'Collecting subset...')
+ # read kwargs
+ B, task, obs_dim = kwargs['B'], kwargs['task'], kwargs['obs_dim']
+
+ # take subset based on all_rows_to_consider
+ myprint(f'Taking subset of data based on all_rows_to_consider...')
+ all_processed_rows_of_obs = all_rows_of_obs_OG[all_rows_to_consider]
+ all_attn_masks = all_attn_masks_OG[all_rows_to_consider]
+ assert (all_processed_rows_of_obs.shape == (len(all_rows_to_consider), B, *obs_dim) and
+ all_attn_masks.shape == (len(all_rows_to_consider), B))
+
+ # reshape
+ myprint(f'Reshaping data...')
+ all_attn_masks = all_attn_masks.reshape(-1)
+ all_processed_rows_of_obs = all_processed_rows_of_obs.reshape(-1, *obs_dim)
+ all_processed_rows_of_obs = all_processed_rows_of_obs[all_attn_masks]
+ assert (all_attn_masks.shape == (len(all_rows_to_consider) * B,) and
+ all_processed_rows_of_obs.shape == (np.sum(all_attn_masks), *obs_dim))
+
+ # collect indices of data
+ myprint(f'Collecting indices of data...')
+ all_indices = np.array([[row_idx, i] for row_idx in all_rows_to_consider for i in range(B)])
+ all_indices = all_indices[all_attn_masks] # this is fine because all attn masks have 0s that only come after 1s
+ assert all_indices.shape == (np.sum(all_attn_masks), 2)
+
+ myprint(f'{all_indices.shape=}, {all_processed_rows_of_obs.shape=}')
+ myprint(('-'*100) + '\n\n\n')
+ return all_indices, all_processed_rows_of_obs
+
+def retrieve_atari(row_of_obs, # query: (xbdim, 3, 4*84, 84) / (xdim *obs_dim)
+ all_processed_rows_of_obs,
+ all_indices,
+ num_to_retrieve,
kwargs
- ):
+ ):
+ """
+ Retrieval for Atari with images, ssim distance, and on GPU.
+ """
assert isinstance(row_of_obs, torch.Tensor)
# read kwargs # Note: B = len of row
- B, attn_key, obs_key, device, task, batch_size_retrieval = kwargs['B'], kwargs['attn_key'], kwargs['obs_key'], kwargs['device'], kwargs['task'], kwargs['batch_size_retrieval']
+ B, device, batch_size_retrieval = kwargs['B'], kwargs['device'], kwargs['batch_size_retrieval']
# batch size of row_of_obs which can be <= B since we process before calling this function
- row_B = row_of_obs.shape[0]
-
+ xbdim = row_of_obs.shape[0]
+
+ # collect subset of data that we can retrieve from
+ ydim = all_processed_rows_of_obs.shape[0]
+
# first argument for ssim
- repeated_row_og = row_of_obs.repeat_interleave(B, dim=0).to(device)
- assert repeated_row_og.shape == (row_B*B, 3, 4*84, 84)
+ xbatch = row_of_obs.repeat_interleave(batch_size_retrieval, dim=0).to(device)
+ assert xbatch.shape == (xbdim * batch_size_retrieval, 3, 4*84, 84)
- # iterate over all other rows
+ # iterate over data that we can retrieve from in batches
all_ssim = []
- all_indices = []
- total = 0
- for other_row_idx in tqdm(all_rows_to_consider):
- other_attn_mask, other_row_of_obs = process_row_atari(dataset['train'][other_row_idx][attn_key], dataset['train'][other_row_idx][obs_key])
-
- # batch size of other_row_of_obs
- other_row_B = other_row_of_obs.shape[0]
- total += other_row_B
-
- # first argument for ssim: RECHECK
- if other_row_B < B: # when other row has less observations than expected
- repeated_row = row_of_obs.repeat_interleave(other_row_B, dim=0).to(device)
- elif other_row_B == B: # otherwise just use the one created before the for loop
- repeated_row = repeated_row_og
- assert repeated_row.shape == (row_B*other_row_B, 3, 4*84, 84)
-
+ for j in range(0, ydim, batch_size_retrieval):
# second argument for ssim
- repeated_other_row = other_row_of_obs.repeat(row_B, 1, 1, 1).to(device)
- assert repeated_other_row.shape == (row_B*other_row_B, 3, 4*84, 84)
+ ybatch = all_processed_rows_of_obs[j:j+batch_size_retrieval]
+ ybdim = ybatch.shape[0]
+ ybatch = ybatch.repeat(xbdim, 1, 1, 1).to(device)
+ assert ybatch.shape == (ybdim * xbdim, 3, 4*84, 84)
+
+ if ybdim < batch_size_retrieval: # for last batch
+ xbatch = row_of_obs.repeat_interleave(ybdim, dim=0).to(device)
+ assert xbatch.shape == (xbdim * ybdim, 3, 4*84, 84)
# compare via ssim and updated all_ssim
- ssim_score = ssim(repeated_row, repeated_other_row, data_range=1.0, size_average=False)
- ssim_score = ssim_score.reshape(row_B, other_row_B)
+ ssim_score = ssim(xbatch, ybatch, data_range=1.0, size_average=False)
+ ssim_score = ssim_score.reshape(xbdim, ybdim)
all_ssim.append(ssim_score)
- # update all_indices
- all_indices.extend([[other_row_idx, i] for i in range(other_row_B)])
-
# concat
all_ssim = torch.cat(all_ssim, dim=1)
- assert all_ssim.shape == (row_B, total)
+ assert all_ssim.shape == (xbdim, ydim)
- all_indices = np.array(all_indices)
- assert all_indices.shape == (total, 2)
+ assert all_indices.shape == (ydim, 2)
# get top-k indices
topk_values, topk_indices = torch.topk(all_ssim, num_to_retrieve, dim=1, largest=True)
topk_indices = topk_indices.cpu().numpy()
- assert topk_indices.shape == (row_B, num_to_retrieve)
+ assert topk_indices.shape == (xbdim, num_to_retrieve)
# convert topk indices to indices in the dataset
- retrieved_indices = np.array(all_indices[topk_indices])
- assert retrieved_indices.shape == (row_B, num_to_retrieve, 2)
-
- # pad the above to expected B
- if row_B < B:
- retrieved_indices = np.concatenate([retrieved_indices, np.zeros((B-row_B, num_to_retrieve, 2), dtype=int)], axis=0)
- assert retrieved_indices.shape == (B, num_to_retrieve, 2)
+ retrieved_indices = all_indices[topk_indices]
+ assert retrieved_indices.shape == (xbdim, num_to_retrieve, 2)
return retrieved_indices
-def build_index_vector(all_rows_of_obs_og,
- all_attn_masks_og,
+def build_index_vector(all_rows_of_obs_OG,
+ all_attn_masks_OG,
all_rows_to_consider,
kwargs
- ):
+ ):
+ """
+ Builds FAISS index for vector observation environments.
+ """
# read kwargs # Note: B = len of row
- B, attn_key, obs_key, device, task, batch_size_retrieval, nb_cores_autofaiss = kwargs['B'], kwargs['attn_key'], kwargs['obs_key'], kwargs['device'], kwargs['task'], kwargs['batch_size_retrieval'], kwargs['nb_cores_autofaiss']
- obs_dim = get_obs_dim(task)
+ nb_cores_autofaiss = kwargs['nb_cores_autofaiss']
- # take subset based on all_rows_to_consider
- myprint(f'Taking subset')
- all_rows_of_obs = all_rows_of_obs_og[all_rows_to_consider]
- all_attn_masks = all_attn_masks_og[all_rows_to_consider]
- assert (all_rows_of_obs.shape == (len(all_rows_to_consider), B, obs_dim) and
- all_attn_masks.shape == (len(all_rows_to_consider), B))
-
- # reshape
- all_attn_masks = all_attn_masks.reshape(-1)
- all_rows_of_obs = all_rows_of_obs.reshape(-1, obs_dim)
- all_rows_of_obs = all_rows_of_obs[all_attn_masks]
- assert all_rows_of_obs.shape == (np.sum(all_attn_masks), obs_dim)
+ # take subset based on all_rows_to_consider, reshape, and save indices of data
+ all_indices, all_processed_rows_of_obs = collect_subset(all_rows_of_obs_OG, all_attn_masks_OG, all_rows_to_consider, kwargs)
- # save indices of data to retrieve from
- myprint(f'Saving indices of data to retrieve from')
- all_indices = np.array([[row_idx, i] for row_idx in all_rows_to_consider for i in range(B)])
- all_indices = all_indices[all_attn_masks] # this is fine because all attn masks have 0s that only come after 1s
- assert all_indices.shape == (np.sum(all_attn_masks), 2)
+ # make sure input to build_index is float, otherwise you will get reading temp file error
+ all_processed_rows_of_obs = all_processed_rows_of_obs.astype(float)
# build index
- myprint(f'Building index...')
- knn_index, knn_index_infos = build_index(embeddings=all_rows_of_obs, # Note: embeddings have to be float to avoid errors in autofaiss / embedding_reader!
+ myprint(('-'*100) + 'Building index...')
+ knn_index, knn_index_infos = build_index(embeddings=all_processed_rows_of_obs, # Note: embeddings have to be float to avoid errors in autofaiss / embedding_reader!
save_on_disk=False,
min_nearest_neighbors_to_retrieve=20, # default: 20
max_index_query_time_ms=10, # default: 10
@@ -179,34 +238,32 @@ def build_index_vector(all_rows_of_obs_og,
metric_type='l2',
nb_cores=nb_cores_autofaiss, # default: None # "The number of cores to use, by default will use all cores" as seen in https://criteo.github.io/autofaiss/getting_started/quantization.html#the-build-index-command
)
+ myprint(('-'*100) + '\n\n\n')
- return knn_index, all_indices
+ return all_indices, knn_index
-def retrieve_vector(row_of_obs, # query: (row_B, dim)
- dataset, # to retrieve from
- all_rows_to_consider, # rows to consider
- num_to_retrieve, # top-k
+def retrieve_vector(row_of_obs, # query: (xbdim, *obs_dim)
+ knn_index,
+ all_indices,
+ num_to_retrieve,
kwargs
- ):
+ ):
+ """
+ Retrieval for vector observation environments.
+ """
assert isinstance(row_of_obs, np.ndarray)
# read few kwargs
B = kwargs['B']
# batch size of row_of_obs which can be <= B since we process before calling this function
- row_B = row_of_obs.shape[0]
+ xbdim = row_of_obs.shape[0]
- # read dataset_tuple
- all_rows_of_obs, all_attn_masks = dataset
-
- # create index and all_indices
- knn_index, all_indices = build_index_vector(all_rows_of_obs, all_attn_masks, all_rows_to_consider, kwargs)
-
# retrieve
myprint(f'Retrieving...')
topk_indices, _ = knn_index.search(row_of_obs, 10 * num_to_retrieve)
topk_indices = topk_indices.astype(int)
- assert topk_indices.shape == (row_B, 10 * num_to_retrieve)
+ assert topk_indices.shape == (xbdim, 10 * num_to_retrieve)
# remove -1s and crop to num_to_retrieve
try:
@@ -219,16 +276,10 @@ def retrieve_vector(row_of_obs, # query: (row_B, dim)
print(f'-------------------------------------------------------------------------------------------------------------------------------------------')
print(f'Leaving some -1s in topk_indices and continuing')
topk_indices = np.array([indices[:num_to_retrieve] for indices in topk_indices])
- assert topk_indices.shape == (row_B, num_to_retrieve)
+ assert topk_indices.shape == (xbdim, num_to_retrieve)
# convert topk indices to indices in the dataset
retrieved_indices = all_indices[topk_indices]
- assert retrieved_indices.shape == (row_B, num_to_retrieve, 2)
-
- # pad the above to expected B
- if row_B < B:
- retrieved_indices = np.concatenate([retrieved_indices, np.zeros((B-row_B, num_to_retrieve, 2), dtype=int)], axis=0)
- assert retrieved_indices.shape == (B, num_to_retrieve, 2)
+ assert retrieved_indices.shape == (xbdim, num_to_retrieve, 2)
- myprint(f'Returning')
return retrieved_indices
\ No newline at end of file
diff --git a/scripts_regent/eval_RandP.py b/scripts_regent/eval_RandP.py
index 07e545c..146b347 100755
--- a/scripts_regent/eval_RandP.py
+++ b/scripts_regent/eval_RandP.py
@@ -15,9 +15,10 @@ from transformers import AutoModelForCausalLM, AutoProcessor, HfArgumentParser
from jat.eval.rl import TASK_NAME_TO_ENV_ID, make
from jat.utils import normalize, push_to_hub, save_video_grid
-from jat_regent.RandP import RandP
+from jat_regent.modeling_RandP import RandP
from datasets import load_from_disk
from datasets.config import HF_DATASETS_CACHE
+from jat_regent.utils import myprint
@dataclass
@@ -70,6 +71,7 @@ def eval_rl(model, processor, task, eval_args):
scores = []
frames = []
for episode in tqdm(range(eval_args.num_episodes), desc=task, unit="episode", leave=False):
+ myprint(('-'*100) + f'{episode=}')
observation, _ = env.reset()
reward = None
rewards = []
@@ -96,6 +98,7 @@ def eval_rl(model, processor, task, eval_args):
frames.append(np.array(env.render(), dtype=np.uint8))
scores.append(sum(rewards))
+ myprint(('-'*100) + '\n\n\n')
env.close()
raw_mean, raw_std = np.mean(scores), np.std(scores)
@@ -145,7 +148,9 @@ def main():
tasks.extend([env_id for env_id in TASK_NAME_TO_ENV_ID.keys() if env_id.startswith(domain)])
device = torch.device("cpu") if eval_args.use_cpu else get_default_device()
- processor = None
+ processor = AutoProcessor.from_pretrained(
+ 'jat-project/jat', cache_dir=None, trust_remote_code=True
+ )
evaluations = {}
video_list = []
@@ -153,14 +158,18 @@ def main():
for task in tqdm(tasks, desc="Evaluation", unit="task", leave=True):
if task in TASK_NAME_TO_ENV_ID.keys():
+ myprint(('-'*100) + f'{task=}')
dataset = load_from_disk(f'{HF_DATASETS_CACHE}/jat-project/jat-dataset-tokenized/{task}')
- model = RandP(dataset)
+ model = RandP(task,
+ dataset,
+ device,)
scores, frames, fps = eval_rl(model, processor, task, eval_args)
evaluations[task] = scores
# Save the video
if eval_args.save_video:
video_list.append(frames)
input_fps.append(fps)
+ myprint(('-'*100) + '\n\n\n')
else:
warnings.warn(f"Task {task} is not supported.")
diff --git a/scripts_regent/offline_retrieval_jat_regent.py b/scripts_regent/offline_retrieval_jat_regent.py
index c83d259..aad678a 100644
--- a/scripts_regent/offline_retrieval_jat_regent.py
+++ b/scripts_regent/offline_retrieval_jat_regent.py
@@ -8,7 +8,7 @@ import time
from datetime import datetime
from datasets import load_from_disk
from datasets.config import HF_DATASETS_CACHE
-from jat_regent.utils import myprint, process_row_atari, process_row_vector, retrieve_atari, retrieve_vector
+from jat_regent.utils import myprint, get_task_info, collect_all_data, process_row_of_obs_atari_full_without_mask, retrieve_atari, retrieve_vector, collect_subset, build_index_vector
import logging
logging.basicConfig(level=logging.DEBUG)
@@ -17,7 +17,8 @@ def main():
parser = argparse.ArgumentParser(description='Build RAAGENT sequence indices')
parser.add_argument('--task', type=str, default='atari-alien', help='Task name')
parser.add_argument('--num_to_retrieve', type=int, default=100, help='Number of states/windows to retrieve')
- parser.add_argument('--nb_cores_autofaiss', type=int, default=16, help='Number of cores to use for faiss in vector observation environments')
+ parser.add_argument('--nb_cores_autofaiss', type=int, default=16, help='Number of cores to use for faiss in vector obs envs')
+ parser.add_argument('--batch_size_retrieval', type=int, default=1024, help='Batch size for retrieval in atari')
args = parser.parse_args()
# load dataset, map, device, for task
@@ -25,77 +26,83 @@ def main():
dataset_path = f"{HF_DATASETS_CACHE}/jat-project/jat-dataset-tokenized/{task}"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- rew_key = 'rewards'
- attn_key = 'attention_mask'
- if task.startswith("atari"):
- obs_key = 'image_observations'
- act_key = 'discrete_actions'
- len_row_tokenized_known = 32 # half of 54
- process_row_fn = process_row_atari
- retrieve_fn = retrieve_atari
- elif task.startswith("babyai"):
- obs_key = 'discrete_observations'# also has 'text_observations' only for raw dataset not for tokenized dataset (as it is combined into discrete_observation in tokenized dataset)
- act_key = 'discrete_actions'
- len_row_tokenized_known = 256 # half of 512
- process_row_fn = lambda attn_mask, row_of_obs, task: process_row_vector(attn_mask, row_of_obs, task, return_numpy=True)
- retrieve_fn = retrieve_vector
- elif task.startswith("metaworld") or task.startswith("mujoco"):
- obs_key = 'continuous_observations'
- act_key = 'continuous_actions'
- len_row_tokenized_known = 256
- process_row_fn = lambda attn_mask, row_of_obs, task: process_row_vector(attn_mask, row_of_obs, task, return_numpy=True)
- retrieve_fn = retrieve_vector
+ rew_key, attn_key, obs_key, act_key, B, obs_dim = get_task_info(task)
dataset = load_from_disk(dataset_path)
with open(f"{dataset_path}/map_from_rows_to_episodes_for_tokenized.json", 'r') as f:
map_from_rows_to_episodes_for_tokenized = json.load(f)
# setup kwargs
- len_dataset = len(dataset['train'])
- B = len_row_tokenized_known
kwargs = {'B': B,
- 'attn_key':attn_key,
- 'obs_key':obs_key,
- 'device':device,
- 'task':task,
- 'batch_size_retrieval':None,
- 'nb_cores_autofaiss':None if task.startswith("atari") else args.nb_cores_autofaiss,
- }
+ 'obs_dim': obs_dim,
+ 'attn_key': attn_key,
+ 'obs_key': obs_key,
+ 'device': device,
+ 'task': task,
+ 'batch_size_retrieval': args.batch_size_retrieval,
+ 'nb_cores_autofaiss': None if task.startswith("atari") else args.nb_cores_autofaiss,
+ }
# collect all observations in a single array (this takes some time) for vector observation environments
- if not task.startswith("atari"):
- myprint("Collecting all observations/attn_masks in a single array")
- all_rows_of_obs = np.array(dataset['train'][obs_key])
- all_attn_masks = np.array(dataset['train'][attn_key]).astype(bool)
+ myprint("Collecting all observations/attn_masks")
+ all_rows_of_obs_OG, all_attn_masks_OG, all_row_idxs = collect_all_data(dataset, task, obs_key)
# iterate over rows
all_retrieved_indices = []
- for row_idx in range(len_dataset):
- myprint(f"\nProcessing row {row_idx}/{len_dataset}")
+ for row_idx in all_row_idxs:
+ myprint(f"\nProcessing row {row_idx}/{len(all_row_idxs)}")
current_ep = map_from_rows_to_episodes_for_tokenized[str(row_idx)]
- attn_mask, row_of_obs = process_row_fn(dataset['train'][row_idx][attn_key], dataset['train'][row_idx][obs_key], task)
+ # get row_of_obs and attn_mask
+ datarow = dataset['train'][row_idx]
+ attn_mask = np.array(datarow[attn_key]).astype(bool)
+ if task.startswith("atari"):
+ row_of_obs = process_row_of_obs_atari_full_without_mask(datarow[obs_key])
+ else:
+ row_of_obs = np.array(datarow[obs_key])
+ row_of_obs = row_of_obs[attn_mask]
+ assert row_of_obs.shape == (np.sum(attn_mask), *obs_dim)
# compare with rows from all but the current episode
- all_other_rows = [idx for idx in range(len_dataset) if map_from_rows_to_episodes_for_tokenized[str(idx)] != current_ep]
+ all_other_row_idxs = [idx for idx in all_row_idxs if map_from_rows_to_episodes_for_tokenized[str(idx)] != current_ep]
# do the retrieval
- retrieved_indices = retrieve_fn(row_of_obs=row_of_obs,
- dataset=dataset if task.startswith("atari") else (all_rows_of_obs, all_attn_masks),
- all_rows_to_consider=all_other_rows,
- num_to_retrieve=args.num_to_retrieve,
- kwargs=kwargs,
- )
+ if task.startswith("atari"):
+ all_indices, all_processed_rows_of_obs = collect_subset(all_rows_of_obs_OG=all_rows_of_obs_OG,
+ all_attn_masks_OG=all_attn_masks_OG,
+ all_rows_to_consider=all_row_idxs,
+ kwargs=kwargs)
+ retrieved_indices = retrieve_atari(row_of_obs=row_of_obs,
+ all_processed_rows_of_obs=all_processed_rows_of_obs,
+ all_indices=all_indices,
+ num_to_retrieve=args.num_to_retrieve,
+ kwargs=kwargs)
+ else:
+ all_indices, knn_index = build_index_vector(all_rows_of_obs_OG=all_rows_of_obs_OG,
+ all_attn_masks_OG=all_attn_masks_OG,
+ all_rows_to_consider=all_other_row_idxs,
+ kwargs=kwargs)
+ retrieved_indices = retrieve_vector(row_of_obs=row_of_obs,
+ knn_index=knn_index,
+ all_indices=all_indices,
+ num_to_retrieve=args.num_to_retrieve,
+ kwargs=kwargs)
+
+ # pad the above to expected B
+ xbdim = row_of_obs.shape[0]
+ if xbdim < B:
+ retrieved_indices = np.concatenate([retrieved_indices, np.zeros((B-xbdim, args.num_to_retrieve, 2), dtype=int)], axis=0)
+ assert retrieved_indices.shape == (B, args.num_to_retrieve, 2)
# collect retrieved indices
all_retrieved_indices.append(retrieved_indices)
# concat
all_retrieved_indices = np.stack(all_retrieved_indices, axis=0)
- assert all_retrieved_indices.shape == (len_dataset, B, args.num_to_retrieve, 2)
+ assert all_retrieved_indices.shape == (len(all_row_idxs), B, args.num_to_retrieve, 2)
# save arrays as bin for easy memmap access and faster loading
- all_retrieved_indices.tofile(f"{dataset_path}/retrieved_indices_{len_dataset}_{B}_{args.num_to_retrieve}_2.bin")
+ all_retrieved_indices.tofile(f"{dataset_path}/retrieved_indices_{len(all_row_idxs)}_{B}_{args.num_to_retrieve}_2.bin")
if __name__ == "__main__":
main()
\ No newline at end of file