|
import argparse |
|
|
|
import numpy as np |
|
|
|
import os |
|
|
|
import shutil |
|
|
|
import torch |
|
import torch.optim as optim |
|
|
|
from torch.utils.data import DataLoader |
|
|
|
from tqdm import tqdm |
|
|
|
import warnings |
|
|
|
from lib.dataset import MegaDepthDataset |
|
from lib.exceptions import NoGradientError |
|
from lib.loss import loss_function |
|
from lib.model import D2Net |
|
|
|
|
|
|
|
# Device selection: prefer the first GPU when CUDA is available.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")

# Seed every random number generator for reproducible runs.
torch.manual_seed(1)
if use_cuda:
    torch.cuda.manual_seed(1)
np.random.seed(1)
|
|
|
|
|
# Command-line interface.
parser = argparse.ArgumentParser(description='Training script')

# Data locations.
parser.add_argument('--dataset_path', type=str, required=True,
                    help='path to the dataset')
parser.add_argument('--scene_info_path', type=str, required=True,
                    help='path to the processed scenes')

# Preprocessing and initial weights.
parser.add_argument('--preprocessing', type=str, default='caffe',
                    help='image preprocessing (caffe or torch)')
parser.add_argument('--model_file', type=str, default='models/d2_ots.pth',
                    help='path to the full model')

# Optimization hyper-parameters.
parser.add_argument('--num_epochs', type=int, default=10,
                    help='number of training epochs')
parser.add_argument('--lr', type=float, default=1e-3,
                    help='initial learning rate')
parser.add_argument('--batch_size', type=int, default=1,
                    help='batch size')
parser.add_argument('--num_workers', type=int, default=4,
                    help='number of workers for data loading')

# Validation split (off by default).
parser.add_argument('--use_validation', dest='use_validation',
                    action='store_true', help='use the validation split')
parser.set_defaults(use_validation=False)

# Logging.
parser.add_argument('--log_interval', type=int, default=250,
                    help='loss logging interval')
parser.add_argument('--log_file', type=str, default='log.txt',
                    help='loss logging file')

# Visualization of training pairs (off by default).
parser.add_argument('--plot', dest='plot', action='store_true',
                    help='plot training pairs')
parser.set_defaults(plot=False)

# Checkpointing.
parser.add_argument('--checkpoint_directory', type=str, default='checkpoints',
                    help='directory for training checkpoints')
parser.add_argument('--checkpoint_prefix', type=str, default='d2',
                    help='prefix for training checkpoints')

args = parser.parse_args()

print(args)
|
|
|
|
|
# When plotting is requested, make sure the output directory exists.
if args.plot:
    plot_path = 'train_vis'
    if not os.path.isdir(plot_path):
        os.mkdir(plot_path)
    else:
        print('[Warning] Plotting directory already exists.')
|
|
|
|
|
# Build the D2-Net model from the given weights file.
model = D2Net(model_file=args.model_file, use_cuda=use_cuda)

# Optimize only the parameters that are left trainable by the model.
trainable_parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_parameters, lr=args.lr)
|
|
|
|
|
# Validation split: a fixed number of pairs per scene, evaluation mode.
if args.use_validation:
    validation_dataset = MegaDepthDataset(
        scene_list_path='megadepth_utils/valid_scenes.txt',
        scene_info_path=args.scene_info_path,
        base_path=args.dataset_path,
        train=False,
        preprocessing=args.preprocessing,
        pairs_per_scene=25
    )
    validation_dataloader = DataLoader(
        validation_dataset, batch_size=args.batch_size,
        num_workers=args.num_workers
    )

# Training split.
training_dataset = MegaDepthDataset(
    scene_list_path='megadepth_utils/train_scenes.txt',
    scene_info_path=args.scene_info_path,
    base_path=args.dataset_path,
    preprocessing=args.preprocessing
)
training_dataloader = DataLoader(
    training_dataset, batch_size=args.batch_size,
    num_workers=args.num_workers
)
|
|
|
|
|
|
|
def process_epoch(
    epoch_idx,
    model, loss_function, optimizer, dataloader, device,
    log_file, args, train=True
):
    """Run one epoch over ``dataloader`` and return the mean loss.

    When ``train`` is True, gradients are enabled and an optimizer step is
    taken after every successful batch; otherwise the epoch is a pure
    evaluation pass.  Batches for which the loss raises ``NoGradientError``
    (no usable correspondences) are skipped.  Progress is written to
    ``log_file`` every ``args.log_interval`` batches and once at epoch end.
    """
    epoch_losses = []

    # Enable/disable autograd globally for this pass.
    torch.set_grad_enabled(train)

    progress_bar = tqdm(enumerate(dataloader), total=len(dataloader))
    for batch_idx, batch in progress_bar:
        if train:
            optimizer.zero_grad()

        # Bookkeeping information the loss function reads from the batch dict.
        batch['train'] = train
        batch['epoch_idx'] = epoch_idx
        batch['batch_idx'] = batch_idx
        batch['batch_size'] = args.batch_size
        batch['preprocessing'] = args.preprocessing
        batch['log_interval'] = args.log_interval

        try:
            loss = loss_function(model, batch, device, plot=args.plot)
        except NoGradientError:
            # No gradient could be computed for this pair -- skip it.
            continue

        # BUG FIX: the loss is a 0-dim tensor; `loss.data.cpu().numpy()[0]`
        # indexes a 0-dim array and raises IndexError on modern PyTorch.
        # `.item()` extracts the Python scalar safely and detaches it.
        current_loss = loss.item()
        epoch_losses.append(current_loss)

        progress_bar.set_postfix(loss=('%.4f' % np.mean(epoch_losses)))

        if batch_idx % args.log_interval == 0:
            log_file.write('[%s] epoch %d - batch %d / %d - avg_loss: %f\n' % (
                'train' if train else 'valid',
                epoch_idx, batch_idx, len(dataloader), np.mean(epoch_losses)
            ))

        if train:
            loss.backward()
            optimizer.step()

    # Final epoch summary; flush so the log survives an interrupted run.
    log_file.write('[%s] epoch %d - avg_loss: %f\n' % (
        'train' if train else 'valid',
        epoch_idx,
        np.mean(epoch_losses)
    ))
    log_file.flush()

    return np.mean(epoch_losses)
|
|
|
|
|
|
|
# Create the checkpoint directory unless it is already there.
if not os.path.isdir(args.checkpoint_directory):
    os.mkdir(args.checkpoint_directory)
else:
    print('[Warning] Checkpoint directory already exists.')

# Open the log in append mode so earlier runs are preserved.
if os.path.exists(args.log_file):
    print('[Warning] Log file already exists.')
log_file = open(args.log_file, 'a+')
|
|
|
|
|
# Loss curves collected over the whole run.
train_loss_history = []
validation_loss_history = []

# Establish the baseline validation loss before any training happens.
if args.use_validation:
    validation_dataset.build_dataset()
    min_validation_loss = process_epoch(
        0,
        model, loss_function, optimizer, validation_dataloader, device,
        log_file, args,
        train=False
    )
|
|
|
|
|
for epoch_idx in range(1, args.num_epochs + 1):
    # Resample a fresh set of training pairs for every epoch.
    training_dataset.build_dataset()
    epoch_train_loss = process_epoch(
        epoch_idx,
        model, loss_function, optimizer, training_dataloader, device,
        log_file, args
    )
    train_loss_history.append(epoch_train_loss)

    if args.use_validation:
        epoch_validation_loss = process_epoch(
            epoch_idx,
            model, loss_function, optimizer, validation_dataloader, device,
            log_file, args,
            train=False
        )
        validation_loss_history.append(epoch_validation_loss)

    # Save a numbered checkpoint for this epoch.
    checkpoint_path = os.path.join(
        args.checkpoint_directory,
        '%s.%02d.pth' % (args.checkpoint_prefix, epoch_idx)
    )
    checkpoint = {
        'args': args,
        'epoch_idx': epoch_idx,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'train_loss_history': train_loss_history,
        'validation_loss_history': validation_loss_history
    }
    torch.save(checkpoint, checkpoint_path)

    # Keep a copy of the best model seen so far on the validation split.
    if (
        args.use_validation and
        validation_loss_history[-1] < min_validation_loss
    ):
        min_validation_loss = validation_loss_history[-1]
        best_checkpoint_path = os.path.join(
            args.checkpoint_directory,
            '%s.best.pth' % args.checkpoint_prefix
        )
        shutil.copy(checkpoint_path, best_checkpoint_path)

log_file.close()
|
|