import argparse
import json
import os
import uuid

import torch
import torch.optim as optim
from torch.utils.data import random_split
from torch_geometric.loader import DataLoader

from data import PolyphemusDataset
from model import VAE
from training import PolyphemusTrainer, ExpDecayLRScheduler, StepBetaScheduler
from utils import set_seed, print_params, print_divider


def main():
    parser = argparse.ArgumentParser(
        description='Trains Polyphemus.'
    )
    parser.add_argument(
        'dataset_dir',
        type=str,
        help='Directory of the Polyphemus dataset to be used for training.'
    )
    parser.add_argument(
        'output_dir',
        type=str,
        help='Directory to save the output of the training.'
    )
    parser.add_argument(
        'config_file',
        type=str,
        help='Path to the JSON training configuration file.'
    )
    parser.add_argument(
        '--model_name',
        type=str,
        help='Name of the model to be trained.'
    )
    parser.add_argument(
        '--save_every',
        type=int,
        default=10,
        help="If set to n, the script will save the model every n batches. "
             "Default is 10."
    )
    parser.add_argument(
        '--print_every',
        type=int,
        default=1,
        help="If set to n, the script will print statistics every n batches. "
             "Default is 1."
    )
    parser.add_argument(
        '--eval',
        action='store_true',
        default=False,
        help='Flag to enable evaluation on a validation set.'
    )
    parser.add_argument(
        '--eval_every',
        type=int,
        help="If the --eval flag is set and this is set to n, the script "
             "will evaluate the model on the validation set every n batches. "
             "Default is once per epoch."
    )
    parser.add_argument(
        '--use_gpu',
        action='store_true',
        default=False,
        help='Flag to enable or disable GPU usage. Default is False.'
    )
    parser.add_argument(
        '--gpu_id',
        type=int,
        default=0,
        help='Index of the GPU to be used. Default is 0.'
    )
    parser.add_argument(
        '--num_workers',
        type=int,
        default=10,
        help="The number of processes to use for loading the data. "
             "Default is 10."
    )
    parser.add_argument(
        '--tr_split',
        type=float,
        default=0.7,
        help="Fraction of samples in the dataset used for the training "
             "split. Default is 0.7."
    )
    parser.add_argument(
        '--vl_split',
        type=float,
        default=0.1,
        help="Fraction of samples in the dataset used for the validation "
             "split. Default is 0.1. This value is ignored if the --eval "
             "option is not specified."
    )
    parser.add_argument(
        '--max_epochs',
        type=int,
        default=100,
        help='Maximum number of training epochs. Default is 100.'
    )
    parser.add_argument(
        '--seed',
        type=int,
        help='Seed for reproducible runs.'
    )
    args = parser.parse_args()

    print_divider()

    if args.seed is not None:
        set_seed(args.seed)

    device = torch.device("cuda") if args.use_gpu else torch.device("cpu")
    if args.use_gpu:
        torch.cuda.set_device(args.gpu_id)

    # Load the training configuration from the JSON file
    print("Loading the configuration file {}...".format(args.config_file))
    with open(args.config_file, 'r') as f:
        training_config = json.load(f)

    n_bars = training_config['model']['n_bars']
    batch_size = training_config['batch_size']

    print("Preparing datasets and dataloaders...")
    dataset = PolyphemusDataset(args.dataset_dir, n_bars)

    # Split the dataset into training/(validation)/test subsets
    tr_len = int(args.tr_split * len(dataset))
    if args.eval:
        vl_len = int(args.vl_split * len(dataset))
        ts_len = len(dataset) - tr_len - vl_len
        lengths = (tr_len, vl_len, ts_len)
    else:
        ts_len = len(dataset) - tr_len
        lengths = (tr_len, ts_len)

    split = random_split(dataset, lengths)
    tr_set = split[0]
    vl_set = split[1] if args.eval else None

    trainloader = DataLoader(tr_set, batch_size=batch_size, shuffle=True,
                             num_workers=args.num_workers)
    if args.eval:
        validloader = DataLoader(vl_set, batch_size=batch_size, shuffle=False,
                                 num_workers=args.num_workers)
        # Evaluate every --eval_every batches; default is once per epoch
        eval_every = (args.eval_every if args.eval_every is not None
                      else len(trainloader))
    else:
        validloader = None
        eval_every = None

    model_name = (args.model_name if args.model_name is not None
                  else str(uuid.uuid1()))
    model_dir = os.path.join(args.output_dir, model_name)

    # Create output directory if it does not exist
    os.makedirs(args.output_dir, exist_ok=True)

    # Create model output directory (raise an error if it already exists to
    # avoid overwriting a trained model)
    os.makedirs(model_dir, exist_ok=False)

    # Create the model
    print("Creating the model and moving it to the {} device...".format(device))
    vae = VAE(**training_config['model'], device=device).to(device)
    print_params(vae)
    print()

    # Create optimizer and schedulers
    optimizer = optim.Adam(vae.parameters(), **training_config['optimizer'])
    lr_scheduler = ExpDecayLRScheduler(
        optimizer=optimizer,
        **training_config['lr_scheduler']
    )
    beta_scheduler = StepBetaScheduler(**training_config['beta_scheduler'])

    # Save the configuration alongside the model checkpoints
    config_path = os.path.join(model_dir, 'configuration')
    torch.save(training_config, config_path)

    print("Starting training...")
    print_divider()

    trainer = PolyphemusTrainer(
        model_dir,
        vae,
        optimizer,
        lr_scheduler=lr_scheduler,
        beta_scheduler=beta_scheduler,
        save_every=args.save_every,
        print_every=args.print_every,
        eval_every=eval_every,
        device=device
    )
    trainer.train(trainloader, validloader=validloader, epochs=args.max_epochs)


if __name__ == '__main__':
    main()
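
# Example invocation (paths are illustrative):
#
#   python train.py dataset/ runs/ config.json --eval --use_gpu --seed 42
#
# The script reads the following top-level keys from the JSON configuration
# file: "batch_size", "model" (which must include "n_bars"), "optimizer",
# "lr_scheduler" and "beta_scheduler". A minimal sketch of such a file is
# shown below; the nested hyperparameter names and values are illustrative
# placeholders, since the exact keyword arguments accepted by VAE,
# ExpDecayLRScheduler and StepBetaScheduler are defined elsewhere in the
# repository:
#
#   {
#       "batch_size": 32,
#       "model": {"n_bars": 2},
#       "optimizer": {"lr": 1e-4},
#       "lr_scheduler": {},
#       "beta_scheduler": {}
#   }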