# polyphemus/train.py
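"""Training script for Polyphemus.

Parses the command-line arguments, loads a JSON training configuration, splits
the dataset into training, test and (optionally) validation subsets, builds the
VAE with its optimizer and schedulers, and runs the PolyphemusTrainer loop.

Example invocation (paths and values below are placeholders):

    python train.py dataset/ output/ config.json --use_gpu --eval --seed 42
"""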
import argparse
import json
import os
import uuid

import torch
import torch.optim as optim
from torch.utils.data import random_split
from torch_geometric.loader import DataLoader

from data import PolyphemusDataset
from model import VAE
from training import PolyphemusTrainer, ExpDecayLRScheduler, StepBetaScheduler
from utils import set_seed, print_params, print_divider


def main():
parser = argparse.ArgumentParser(
description='Trains Polyphemus.'
)
parser.add_argument(
'dataset_dir',
type=str,
help='Directory of the Polyphemus dataset to be used for training.'
)
parser.add_argument(
'output_dir',
type=str,
help='Directory to save the output of the training.'
)
parser.add_argument(
'config_file',
type=str,
help='Path to the JSON training configuration file.'
)
    parser.add_argument(
        '--model_name',
        type=str,
        help='Name of the model to be trained. If not specified, a random '
             'UUID is used.'
    )
parser.add_argument(
'--save_every',
type=int,
default=10,
help="If set to n, the script will save the model every n batches. "
"Default is 10."
)
parser.add_argument(
'--print_every',
type=int,
default=1,
help="If set to n, the script will print statistics every n batches. "
"Default is 1."
)
parser.add_argument(
'--eval',
action='store_true',
default=False,
help='Flag to enable evaluation on a validation set.'
)
parser.add_argument(
'--eval_every',
type=int,
help="If the eval flag is set, when set to n, the script will evaluate "
"the model on the validation set every n batches. "
"Default is every epoch."
)
    parser.add_argument(
        '--use_gpu',
        action='store_true',
        default=False,
        help='Flag to enable GPU usage. Disabled by default.'
    )
    parser.add_argument(
        '--gpu_id',
        type=int,
        default=0,
        help='Index of the GPU to be used. Default is 0.'
    )
    parser.add_argument(
        '--num_workers',
        type=int,
        default=10,
        help="The number of processes to use for loading the data. "
             "Default is 10."
    )
    parser.add_argument(
        '--tr_split',
        type=float,
        default=0.7,
        help="Fraction of the dataset used for the training split. "
             "Default is 0.7."
    )
    parser.add_argument(
        '--vl_split',
        type=float,
        default=0.1,
        help="Fraction of the dataset used for the validation split. "
             "Default is 0.1. This value is ignored if the --eval option is "
             "not specified."
    )
    parser.add_argument(
        '--max_epochs',
        type=int,
        default=100,
        help='Maximum number of training epochs. Default is 100.'
    )
    parser.add_argument(
        '--seed',
        type=int,
        help='Random seed for reproducibility. If not set, no seeding is done.'
    )
args = parser.parse_args()
print_divider()
if args.seed is not None:
set_seed(args.seed)
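    # Select the compute device; bind the requested GPU index when using CUDA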
device = torch.device("cuda") if args.use_gpu else torch.device("cpu")
if args.use_gpu:
torch.cuda.set_device(args.gpu_id)
    # Load the JSON training configuration
    print("Loading the configuration file {}...".format(args.config_file))
    with open(args.config_file, 'r') as f:
        training_config = json.load(f)
n_bars = training_config['model']['n_bars']
batch_size = training_config['batch_size']
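    # NOTE: beyond the keys read here, the configuration file is assumed to
    # provide the keyword arguments consumed below: training_config['model']
    # for VAE, 'optimizer' for Adam, 'lr_scheduler' for ExpDecayLRScheduler
    # and 'beta_scheduler' for StepBetaScheduler.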
print("Preparing datasets and dataloaders...")
dataset = PolyphemusDataset(args.dataset_dir, n_bars)
tr_len = int(args.tr_split * len(dataset))
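    # Split into train/test, plus a validation subset when --eval is set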
if args.eval:
vl_len = int(args.vl_split * len(dataset))
ts_len = len(dataset) - tr_len - vl_len
lengths = (tr_len, vl_len, ts_len)
else:
ts_len = len(dataset) - tr_len
lengths = (tr_len, ts_len)
split = random_split(dataset, lengths)
tr_set = split[0]
vl_set = split[1] if args.eval else None
trainloader = DataLoader(tr_set, batch_size=batch_size, shuffle=True,
num_workers=args.num_workers)
if args.eval:
validloader = DataLoader(vl_set, batch_size=batch_size, shuffle=False,
num_workers=args.num_workers)
        # Honor --eval_every if provided; otherwise evaluate once per epoch
        eval_every = (args.eval_every if args.eval_every is not None
                      else len(trainloader))
else:
validloader = None
eval_every = None
model_name = (args.model_name if args.model_name is not None
else str(uuid.uuid1()))
model_dir = os.path.join(args.output_dir, model_name)
# Create output directory if it does not exist
os.makedirs(args.output_dir, exist_ok=True)
# Create model output directory (raise error if it already exists to avoid
# overwriting a trained model)
os.makedirs(model_dir, exist_ok=False)
# Create the model
print("Creating the model and moving it on {} device...".format(device))
vae = VAE(**training_config['model'], device=device).to(device)
print_params(vae)
print()
    # Create the optimizer and the LR/beta schedulers
optimizer = optim.Adam(vae.parameters(), **training_config['optimizer'])
lr_scheduler = ExpDecayLRScheduler(
optimizer=optimizer,
**training_config['lr_scheduler']
)
beta_scheduler = StepBetaScheduler(**training_config['beta_scheduler'])
# Save config
config_path = os.path.join(model_dir, 'configuration')
torch.save(training_config, config_path)
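    # Saved with torch.save so it can be reloaded later with torch.load.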
print("Starting training...")
print_divider()
trainer = PolyphemusTrainer(
model_dir,
vae,
optimizer,
lr_scheduler=lr_scheduler,
beta_scheduler=beta_scheduler,
save_every=args.save_every,
print_every=args.print_every,
eval_every=eval_every,
device=device
)
trainer.train(trainloader, validloader=validloader, epochs=args.max_epochs)


if __name__ == '__main__':
    main()