|
import argparse |
|
import os |
|
import json |
|
import uuid |
|
|
|
import torch |
|
import os |
|
from torch.utils.data import random_split |
|
from torch_geometric.loader import DataLoader |
|
from data import PolyphemusDataset |
|
import torch.optim as optim |
|
|
|
from model import VAE |
|
from utils import set_seed, print_params, print_divider |
|
from training import PolyphemusTrainer, ExpDecayLRScheduler, StepBetaScheduler |
|
|
|
|
|
def main():
    """Command-line entry point for training Polyphemus.

    Parses CLI arguments, loads the JSON training configuration, builds
    the dataset splits and dataloaders, instantiates the VAE together
    with its optimizer/schedulers, and runs training via
    PolyphemusTrainer. Outputs (checkpoints + configuration) are written
    under ``output_dir/model_name``.
    """

    parser = argparse.ArgumentParser(
        description='Trains Polyphemus.'
    )
    parser.add_argument(
        'dataset_dir',
        type=str,
        help='Directory of the Polyphemus dataset to be used for training.'
    )
    parser.add_argument(
        'output_dir',
        type=str,
        help='Directory to save the output of the training.'
    )
    parser.add_argument(
        'config_file',
        type=str,
        help='Path to the JSON training configuration file.'
    )
    parser.add_argument(
        '--model_name',
        type=str,
        help='Name of the model to be trained.'
    )
    parser.add_argument(
        '--save_every',
        type=int,
        default=10,
        help="If set to n, the script will save the model every n batches. "
             "Default is 10."
    )
    parser.add_argument(
        '--print_every',
        type=int,
        default=1,
        help="If set to n, the script will print statistics every n batches. "
             "Default is 1."
    )
    parser.add_argument(
        '--eval',
        action='store_true',
        default=False,
        help='Flag to enable evaluation on a validation set.'
    )
    parser.add_argument(
        '--eval_every',
        type=int,
        help="If the eval flag is set, when set to n, the script will evaluate "
             "the model on the validation set every n batches. "
             "Default is every epoch."
    )
    parser.add_argument(
        '--use_gpu',
        action='store_true',
        default=False,
        help='Flag to enable or disable GPU usage. Default is False.'
    )
    # Defaults below are native numbers. The original passed strings
    # (e.g. default='0'); argparse does run string defaults through
    # `type`, but native values are clearer and less fragile.
    parser.add_argument(
        '--gpu_id',
        type=int,
        default=0,
        help='Index of the GPU to be used. Default is 0.'
    )
    parser.add_argument(
        '--num_workers',
        type=int,
        default=10,
        help="The number of processes to use for loading the data. "
             "Default is 10."
    )
    parser.add_argument(
        '--tr_split',
        type=float,
        default=0.7,
        help="Percentage of samples in the dataset used for the training split."
             " Default is 0.7."
    )
    parser.add_argument(
        '--vl_split',
        type=float,
        default=0.1,
        help="Percentage of samples in the dataset used for the validation "
             "split. Default is 0.1. This value is ignored if the --eval option is "
             "not specified."
    )
    parser.add_argument(
        '--max_epochs',
        type=int,
        default=100,
        help='Maximum number of training epochs. Default is 100.'
    )
    parser.add_argument(
        '--seed',
        type=int,
        help='Random seed for reproducibility. Unset by default.'
    )

    args = parser.parse_args()
    print_divider()

    # Seed all RNGs up front so dataset splitting and training are
    # reproducible when a seed is given.
    if args.seed is not None:
        set_seed(args.seed)

    device = torch.device("cuda") if args.use_gpu else torch.device("cpu")
    if args.use_gpu:
        # Make the requested GPU the default CUDA device for all
        # subsequent allocations.
        torch.cuda.set_device(args.gpu_id)

    print("Loading the configuration file {}...".format(args.config_file))

    with open(args.config_file, 'r') as f:
        training_config = json.load(f)

    n_bars = training_config['model']['n_bars']
    batch_size = training_config['batch_size']

    print("Preparing datasets and dataloaders...")

    dataset = PolyphemusDataset(args.dataset_dir, n_bars)

    tr_len = int(args.tr_split * len(dataset))

    # The remainder after the train (and optional validation) split is
    # treated as a held-out test split; it is not used in this script.
    if args.eval:
        vl_len = int(args.vl_split * len(dataset))
        ts_len = len(dataset) - tr_len - vl_len
        lengths = (tr_len, vl_len, ts_len)
    else:
        ts_len = len(dataset) - tr_len
        lengths = (tr_len, ts_len)

    split = random_split(dataset, lengths)
    tr_set = split[0]
    vl_set = split[1] if args.eval else None

    trainloader = DataLoader(tr_set, batch_size=batch_size, shuffle=True,
                             num_workers=args.num_workers)
    if args.eval:
        validloader = DataLoader(vl_set, batch_size=batch_size, shuffle=False,
                                 num_workers=args.num_workers)
        # Honor --eval_every when the user provided it; otherwise fall
        # back to evaluating once per epoch. (Previously the parsed
        # --eval_every value was silently ignored.)
        eval_every = (args.eval_every if args.eval_every is not None
                      else len(trainloader))
    else:
        validloader = None
        eval_every = None

    # Fall back to a unique auto-generated name when none was given.
    model_name = (args.model_name if args.model_name is not None
                  else str(uuid.uuid1()))
    model_dir = os.path.join(args.output_dir, model_name)

    os.makedirs(args.output_dir, exist_ok=True)

    # exist_ok=False: deliberately fail rather than overwrite the output
    # of a previous run with the same model name.
    os.makedirs(model_dir, exist_ok=False)

    print("Creating the model and moving it on {} device...".format(device))
    vae = VAE(**training_config['model'], device=device).to(device)
    print_params(vae)
    print()

    optimizer = optim.Adam(vae.parameters(), **training_config['optimizer'])
    lr_scheduler = ExpDecayLRScheduler(
        optimizer=optimizer,
        **training_config['lr_scheduler']
    )
    beta_scheduler = StepBetaScheduler(**training_config['beta_scheduler'])

    # Persist the configuration next to the checkpoints so the model can
    # be reconstructed later from the same hyperparameters.
    config_path = os.path.join(model_dir, 'configuration')
    torch.save(training_config, config_path)

    print("Starting training...")
    print_divider()

    trainer = PolyphemusTrainer(
        model_dir,
        vae,
        optimizer,
        lr_scheduler=lr_scheduler,
        beta_scheduler=beta_scheduler,
        save_every=args.save_every,
        print_every=args.print_every,
        eval_every=eval_every,
        device=device
    )
    trainer.train(trainloader, validloader=validloader, epochs=args.max_epochs)
|
|
|
|
|
# Run training only when executed as a script, not when imported.
if __name__ == '__main__':

    main()
|
|