Spaces:
Runtime error
Runtime error
"""Main training script.""" | |
import os | |
from pathlib import Path | |
import torch | |
from cliport import agents | |
from cliport.dataset import RavensDataset, RavensMultiTaskDataset, RavenMultiTaskDatasetBalance | |
import hydra | |
from pytorch_lightning import Trainer | |
from pytorch_lightning.callbacks import ModelCheckpoint | |
from pytorch_lightning.loggers import WandbLogger | |
import numpy as np | |
from torch.utils.data import DataLoader | |
from torch.utils.data.dataloader import default_collate | |
import IPython | |
import pytorch_lightning as pl | |
from pytorch_lightning.utilities import rank_zero_only | |
import datetime | |
import time | |
def _make_wandb_logger(cfg):
    """Return a ``WandbLogger`` named ``cfg['tag']``, or None.

    None is returned when ``cfg['train']['log']`` is falsy or when the logger
    cannot be constructed. Logging is best-effort, so a failure is reported
    but never aborts training.
    """
    if not cfg['train']['log']:
        return None
    try:
        return WandbLogger(name=cfg['tag'])
    except Exception as err:  # best-effort: report instead of silently swallowing
        print(f"WandbLogger could not be created ({err!r}); continuing without it.")
        return None


def _find_last_checkpoint(checkpoint_path, best_dir):
    """Return the path of an existing ``last.ckpt`` to resume from, or None."""
    # ModelCheckpoint(save_last=True) writes last.ckpt into its ``dirpath``
    # (best_dir), so look there first. The top-level checkpoint directory,
    # which is the only place this script used to look, is kept as a fallback.
    for candidate in (os.path.join(best_dir, 'last.ckpt'),
                      os.path.join(checkpoint_path, 'last.ckpt')):
        if os.path.exists(candidate):
            return candidate
    return None


def _compute_max_epochs(train_cfg):
    """Return the Trainer epoch budget derived from the ``train`` config section."""
    # Convert the step budget into epochs. This presumably assumes one epoch
    # covers about n_demos samples; confirm against the dataset length.
    max_epochs = train_cfg['n_steps'] * train_cfg['batch_size'] // train_cfg['n_demos']
    if train_cfg['training_step_scale'] > 0:
        # A positive scale overrides the derived value and is used directly
        # as the epoch count (meant to scale training time per task set).
        max_epochs = train_cfg['training_step_scale']
    return max_epochs


def _build_datasets(cfg, data_dir, task, n_demos, n_val):
    """Construct the ``(train_ds, val_ds)`` pair selected by ``cfg['dataset']['type']``."""
    dataset_type = cfg['dataset']['type']
    # 'multi' is tested before 'weighted', so a type name containing both
    # substrings selects RavensMultiTaskDataset.
    if 'multi' in dataset_type:
        train_ds = RavensMultiTaskDataset(data_dir, cfg, group=task, mode='train',
                                          n_demos=n_demos, augment=True)
        val_ds = RavensMultiTaskDataset(data_dir, cfg, group=task, mode='val',
                                        n_demos=n_val, augment=False)
    elif 'weighted' in dataset_type:
        train_ds = RavenMultiTaskDatasetBalance(data_dir, cfg, group=task, mode='train',
                                                n_demos=n_demos, augment=True)
        val_ds = RavenMultiTaskDatasetBalance(data_dir, cfg, group=task, mode='val',
                                              n_demos=n_val, augment=False)
    else:
        train_ds = RavensDataset(os.path.join(data_dir, '{}-train'.format(task)), cfg,
                                 n_demos=n_demos, augment=True)
        val_ds = RavensDataset(os.path.join(data_dir, '{}-val'.format(task)), cfg,
                               n_demos=n_val, augment=False)
    return train_ds, val_ds


def main(cfg):
    """Train an agent from ``cliport.agents`` with a PyTorch Lightning Trainer.

    Builds the optional W&B logger, the checkpoint callback, the Trainer, the
    datasets and data loaders. It then instantiates
    ``agents.names[cfg['train']['agent']]`` and runs ``trainer.fit``. Training
    resumes from an existing ``last.ckpt`` when
    ``cfg['train']['load_from_last_ckpt']`` is set.

    Args:
        cfg: Nested config mapping. This function reads ``tag``, ``debug``,
            ``dataset.type`` and, under ``train``: ``log``, ``train_dir``,
            ``load_from_last_ckpt``, ``n_steps``, ``batch_size``, ``n_demos``,
            ``training_step_scale``, ``gpu``, ``data_dir``, ``task``,
            ``agent`` and ``n_val``.

    NOTE(review): the ``__main__`` guard calls ``main()`` with no arguments.
    That only works if ``cfg`` is injected by a decorator such as
    ``@hydra.main(...)``. ``hydra`` is imported, but no such decorator is
    visible here. Confirm it is present.
    """
    # Logger (optional; training proceeds without it).
    wandb_logger = _make_wandb_logger(cfg)

    # Checkpoint saver. No ``monitor`` is set, so per the Lightning docs the
    # single kept checkpoint is simply the most recent one. ``save_last`` also
    # writes last.ckpt into ``best_dir``.
    checkpoint_path = os.path.join(cfg['train']['train_dir'], 'checkpoints')
    best_dir = os.path.join(checkpoint_path, 'best')
    found_checkpoint = _find_last_checkpoint(checkpoint_path, best_dir)
    last_checkpoint = (found_checkpoint
                       if found_checkpoint and cfg['train']['load_from_last_ckpt']
                       else None)
    checkpoint_callback = [ModelCheckpoint(
        dirpath=best_dir,
        save_top_k=1,
        every_n_epochs=3,
        save_last=True,
    )]

    # Trainer
    max_epochs = _compute_max_epochs(cfg['train'])
    trainer = Trainer(
        accelerator='gpu',
        devices=cfg['train']['gpu'],
        fast_dev_run=cfg['debug'],
        logger=wandb_logger,
        callbacks=checkpoint_callback,
        max_epochs=max_epochs,
        sync_batchnorm=True,
        log_every_n_steps=30,
    )
    print(f"max epochs: {max_epochs}!")
    # The actual resumption happens through ``trainer.fit(..., ckpt_path=...)``
    # below; this only reports it.
    if last_checkpoint:
        print(f"Resuming: {last_checkpoint}")

    # Config
    data_dir = cfg['train']['data_dir']
    task = cfg['train']['task']
    agent_type = cfg['train']['agent']
    n_demos = cfg['train']['n_demos']
    n_val = cfg['train']['n_val']
    name = '{}-{}-{}'.format(task, agent_type, n_demos)

    # Datasets and loaders
    train_ds, val_ds = _build_datasets(cfg, data_dir, task, n_demos, n_val)
    train_loader = DataLoader(train_ds, shuffle=True, pin_memory=True,
                              batch_size=cfg['train']['batch_size'], num_workers=1)
    test_loader = DataLoader(val_ds, shuffle=False, pin_memory=True,
                             batch_size=cfg['train']['batch_size'], num_workers=1)

    # Initialize agent
    agent = agents.names[agent_type](name, cfg, train_loader, test_loader)

    dt_string = datetime.datetime.now().strftime("%d_%m_%Y_%H:%M:%S")
    print("current time:", dt_string)
    # perf_counter is monotonic, so it is the right clock for measuring a duration.
    start_time = time.perf_counter()

    # Main training loop
    trainer.fit(agent, ckpt_path=last_checkpoint)
    print(f"elapsed training time (s): {time.perf_counter() - start_time:.1f}")
if __name__ == "__main__":
    # Script entry point. ``main`` is invoked with no arguments here, so its
    # ``cfg`` must be supplied by a decorator on ``main`` (presumably
    # ``@hydra.main``). NOTE(review): confirm that decorator exists.
    main()