"""Data collection script."""
import os
import random
import traceback

import hydra
import IPython
import numpy as np

from cliport import tasks
from cliport.dataset import RavensDataset
from cliport.environments.environment import Environment
def _initial_seed(mode):
    """Return the seed value that precedes the first episode seed for ``mode``.

    The collection loop adds 2 before every episode, so the first seed
    actually used is 0 for 'train', 1 for 'val' and 10001 for 'test'.

    Raises:
        Exception: if ``mode`` is not one of 'train', 'val' or 'test'.
    """
    if mode == 'train':
        return -2
    if mode == 'val':  # NOTE: beware of increasing val set to >100
        return -1
    if mode == 'test':
        return -1 + 10000
    raise Exception("Invalid mode. Valid options: train, val, test")


def _print_current_traceback():
    """Print the traceback of the exception currently being handled.

    The output is syntax-highlighted when pygments is importable; otherwise
    the plain traceback is printed, so that a missing optional package cannot
    turn a recoverable episode failure into a crash of the whole run.
    """
    trace = traceback.format_exc()
    try:
        from pygments import highlight
        from pygments.lexers import PythonLexer
        from pygments.formatters import TerminalFormatter
    except ImportError:
        print(trace)
        return
    print(highlight(trace, PythonLexer(), TerminalFormatter()))


def main(cfg):
    """Roll out the task's scripted oracle and store the successful episodes.

    Args:
        cfg: Mapping-style configuration. Keys read here: 'assets_root',
            'disp', 'shared_memory', 'record' (including 'save_video'),
            'task', 'mode', 'save_data', 'data_dir', 'n' and, optionally,
            'regenerate_data'.

    Raises:
        Exception: if the dataset reports no previous seed
            (``dataset.max_seed < 0``) and the mode is not 'train', 'val' or
            'test'; or if the mode is 'val' and the next seed would exceed
            9999, i.e. run into the range used for test seeds.
    """
    # Initialize environment and task.
    env = Environment(
        cfg['assets_root'],
        disp=cfg['disp'],
        shared_memory=cfg['shared_memory'],
        hz=480,
        record_cfg=cfg['record'],
    )
    task = tasks.names[cfg['task']]()
    task.mode = cfg['mode']
    record = cfg['record']['save_video']
    save_data = cfg['save_data']

    # Initialize scripted oracle agent and dataset.
    agent = task.oracle(env)
    data_path = os.path.join(cfg['data_dir'], "{}-{}".format(cfg['task'], task.mode))
    dataset = RavensDataset(data_path, cfg, n_demos=0, augment=False)
    print(f"Saving to: {data_path}")
    print(f"Mode: {task.mode}")

    # Train seeds are even and val/test seeds are odd. Test seeds are offset by 10000.
    seed = dataset.max_seed
    if seed < 0:
        seed = _initial_seed(task.mode)

    # NOTE: this tests for the presence of the key, not for its value.
    if 'regenerate_data' in cfg:
        dataset.n_episodes = 0

    # Cap the number of attempts so that repeatedly failing episodes cannot
    # keep the loop running forever.
    max_eps = 3 * cfg['n']
    curr_run_eps = 0

    # Collect training data from oracle demonstrations.
    while dataset.n_episodes < cfg['n'] and curr_run_eps < max_eps:
        episode, total_reward = [], 0
        seed += 2

        # Set seeds.
        np.random.seed(seed)
        random.seed(seed)
        print('Oracle demo: {}/{} | Seed: {}'.format(dataset.n_episodes + 1, cfg['n'], seed))

        # Unlikely, but a safety check to prevent leaks. It is deliberately
        # placed outside the try block below: the broad handler there would
        # otherwise swallow it and the run would keep spending attempts
        # instead of aborting.
        if task.mode == 'val' and seed > (-1 + 10000):
            raise Exception("!!! Seeds for val set will overlap with the test set !!!")

        curr_run_eps += 1  # counts failed attempts too, so the loop always exits
        try:
            env.set_task(task)
            obs = env.reset()
            info = env.info
            reward = 0

            # Start video recording (NOTE: super slow)
            if record:
                env.start_rec(f'{dataset.n_episodes+1:06d}')

            # Rollout expert policy
            for _ in range(task.max_steps):
                act = agent.act(obs, info)
                episode.append((obs, act, reward, info))
                lang_goal = info['lang_goal']
                obs, reward, done, info = env.step(act)
                total_reward += reward
                print(f'Total Reward: {total_reward:.3f} | Done: {done} | Goal: {lang_goal}')
                if done:
                    break

            if record:
                env.end_rec()
        except Exception:
            # Best effort: report the failure and move on to the next seed.
            _print_current_traceback()
            if record:
                env.end_rec()
            continue

        # Final observation of the episode; no action follows it.
        episode.append((obs, None, reward, info))

        # Only save completed demonstrations.
        if save_data and total_reward > 0.99:
            dataset.add(seed, episode)

            if hasattr(env, 'blender_recorder'):
                blender_path = '{}/blender_demo_{}.pkl'.format(data_path, dataset.n_episodes)
                print("blender pickle saved to ", blender_path)
                env.blender_recorder.save(blender_path)
if __name__ == '__main__':
    # NOTE(review): main() declares a required `cfg` parameter but is called
    # here without arguments. `hydra` is imported yet never used in this file,
    # so presumably a `@hydra.main(...)` decorator on main() is missing —
    # confirm and restore it with the project's config path and name.
    main()