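"""Training script for Polyphemus.

Parses the command-line arguments, loads the JSON training configuration,
builds the dataset splits and dataloaders, instantiates the VAE together
with its optimizer and schedulers, and runs the training loop through
PolyphemusTrainer.

Example invocation (the script filename ``train.py`` and the paths are
assumptions):

    python train.py /path/to/dataset /path/to/output config.json --eval --seed 42
"""
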
import argparse
import json
import os
import uuid

import torch
import torch.optim as optim
from torch.utils.data import random_split
from torch_geometric.loader import DataLoader

from data import PolyphemusDataset
from model import VAE
from training import PolyphemusTrainer, ExpDecayLRScheduler, StepBetaScheduler
from utils import set_seed, print_params, print_divider

def main():
    parser = argparse.ArgumentParser(
        description='Trains Polyphemus.'
    )
    parser.add_argument(
        'dataset_dir',
        type=str,
        help='Directory of the Polyphemus dataset to be used for training.'
    )
    parser.add_argument(
        'output_dir',
        type=str,
        help='Directory to save the output of the training.'
    )
    parser.add_argument(
        'config_file',
        type=str,
        help='Path to the JSON training configuration file.'
    )
    parser.add_argument(
        '--model_name',
        type=str,
        help='Name of the model to be trained.'
    )
    parser.add_argument(
        '--save_every',
        type=int,
        default=10,
        help="If set to n, the script will save the model every n batches. "
             "Default is 10."
    )
    parser.add_argument(
        '--print_every',
        type=int,
        default=1,
        help="If set to n, the script will print statistics every n batches. "
             "Default is 1."
    )
    parser.add_argument(
        '--eval',
        action='store_true',
        default=False,
        help='Flag to enable evaluation on a validation set.'
    )
    parser.add_argument(
        '--eval_every',
        type=int,
        help="If the --eval flag is set and this option is set to n, the "
             "script will evaluate the model on the validation set every n "
             "batches. Default is every epoch."
    )
    parser.add_argument(
        '--use_gpu',
        action='store_true',
        default=False,
        help='Flag to enable GPU usage. Default is False.'
    )
    parser.add_argument(
        '--gpu_id',
        type=int,
        default=0,
        help='Index of the GPU to be used. Default is 0.'
    )
    parser.add_argument(
        '--num_workers',
        type=int,
        default=10,
        help="Number of worker processes to use for loading the data. "
             "Default is 10."
    )
    parser.add_argument(
        '--tr_split',
        type=float,
        default=0.7,
        help="Fraction of the samples in the dataset used for the training "
             "split. Default is 0.7."
    )
    parser.add_argument(
        '--vl_split',
        type=float,
        default=0.1,
        help="Fraction of the samples in the dataset used for the validation "
             "split. Default is 0.1. This value is ignored if the --eval "
             "option is not specified."
    )
    parser.add_argument(
        '--max_epochs',
        type=int,
        default=100,
        help='Maximum number of training epochs. Default is 100.'
    )
    parser.add_argument(
        '--seed',
        type=int,
        help='Random seed for reproducibility.'
    )
    args = parser.parse_args()

    print_divider()

    if args.seed is not None:
        set_seed(args.seed)

    device = torch.device("cuda") if args.use_gpu else torch.device("cpu")
    if args.use_gpu:
        torch.cuda.set_device(args.gpu_id)

    # Load the training configuration from the JSON file
    print("Loading the configuration file {}...".format(args.config_file))
    with open(args.config_file, 'r') as f:
        training_config = json.load(f)
    n_bars = training_config['model']['n_bars']
    batch_size = training_config['batch_size']
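
    # A minimal sketch of the expected configuration layout (the values are
    # illustrative assumptions; the exact keys inside "model", "optimizer",
    # "lr_scheduler" and "beta_scheduler" are whatever the VAE, optim.Adam,
    # ExpDecayLRScheduler and StepBetaScheduler constructors accept):
    #
    # {
    #     "batch_size": 32,
    #     "model": {"n_bars": 2, ...},
    #     "optimizer": {"lr": 1e-4},
    #     "lr_scheduler": {...},
    #     "beta_scheduler": {...}
    # }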
print("Preparing datasets and dataloaders...")
dataset = PolyphemusDataset(args.dataset_dir, n_bars)
tr_len = int(args.tr_split * len(dataset))
if args.eval:
vl_len = int(args.vl_split * len(dataset))
ts_len = len(dataset) - tr_len - vl_len
lengths = (tr_len, vl_len, ts_len)
else:
ts_len = len(dataset) - tr_len
lengths = (tr_len, ts_len)
split = random_split(dataset, lengths)
tr_set = split[0]
vl_set = split[1] if args.eval else None
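
    # Note: the remaining samples form a held-out test split that this script
    # never uses for training or validation.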

    trainloader = DataLoader(tr_set, batch_size=batch_size, shuffle=True,
                             num_workers=args.num_workers)
    if args.eval:
        validloader = DataLoader(vl_set, batch_size=batch_size, shuffle=False,
                                 num_workers=args.num_workers)
        # Honor --eval_every if given; otherwise evaluate once per epoch
        eval_every = (args.eval_every if args.eval_every is not None
                      else len(trainloader))
    else:
        validloader = None
        eval_every = None
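
    # Use the provided model name, or fall back to a unique UUID-based name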
    model_name = (args.model_name if args.model_name is not None
                  else str(uuid.uuid1()))
    model_dir = os.path.join(args.output_dir, model_name)

    # Create the output directory if it does not exist
    os.makedirs(args.output_dir, exist_ok=True)

    # Create the model output directory (raise an error if it already exists
    # to avoid overwriting a trained model)
    os.makedirs(model_dir, exist_ok=False)

    # Create the model
    print("Creating the model and moving it to the {} device...".format(device))
    vae = VAE(**training_config['model'], device=device).to(device)
    print_params(vae)
    print()

    # Create the optimizer and the schedulers
    optimizer = optim.Adam(vae.parameters(), **training_config['optimizer'])
    lr_scheduler = ExpDecayLRScheduler(
        optimizer=optimizer,
        **training_config['lr_scheduler']
    )
    beta_scheduler = StepBetaScheduler(**training_config['beta_scheduler'])
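    # (Presumably, beta weights the KL term of the VAE objective, as in
    # beta-VAE-style KL annealing; this reading is an assumption based on
    # the scheduler's name.)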

    # Save config
    config_path = os.path.join(model_dir, 'configuration')
    torch.save(training_config, config_path)
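
    # The saved configuration can later be restored, e.g. when reloading a
    # checkpoint for inference (a sketch, assuming the same model_dir):
    #
    #     training_config = torch.load(os.path.join(model_dir, 'configuration'))
    #     vae = VAE(**training_config['model'], device=device).to(device)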
print("Starting training...")
print_divider()
trainer = PolyphemusTrainer(
model_dir,
vae,
optimizer,
lr_scheduler=lr_scheduler,
beta_scheduler=beta_scheduler,
save_every=args.save_every,
print_every=args.print_every,
eval_every=eval_every,
device=device
)
trainer.train(trainloader, validloader=validloader, epochs=args.max_epochs)
if __name__ == '__main__':
main()