import argparse
import logging
import os
import pprint
import random
import warnings

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.optim import AdamW
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

from dataset.hypersim import Hypersim
from dataset.kitti import KITTI
from dataset.vkitti2 import VKITTI2
from depth_anything_v2.dpt import DepthAnythingV2
from util.dist_helper import setup_distributed
from util.loss import SiLogLoss
from util.metric import eval_depth
from util.utils import init_log


# Command-line interface. Defined at module level so `main()` can parse it.
parser = argparse.ArgumentParser(description='Depth Anything V2 for Metric Depth Estimation')

parser.add_argument('--encoder', default='vitl', choices=['vits', 'vitb', 'vitl', 'vitg'])
parser.add_argument('--dataset', default='hypersim', choices=['hypersim', 'vkitti'])
parser.add_argument('--img-size', default=518, type=int)
parser.add_argument('--min-depth', default=0.001, type=float)
parser.add_argument('--max-depth', default=20, type=float)
parser.add_argument('--epochs', default=40, type=int)
parser.add_argument('--bs', default=2, type=int)
parser.add_argument('--lr', default=0.000005, type=float)
parser.add_argument('--pretrained-from', type=str)
parser.add_argument('--save-path', type=str, required=True)
# NOTE(review): --local-rank is parsed but never read below; the local rank
# actually used comes from the LOCAL_RANK environment variable.
parser.add_argument('--local-rank', default=0, type=int)
parser.add_argument('--port', default=None, type=int)


def _rank_warning_category():
    """Return NumPy's ``RankWarning`` class, or ``None`` if it cannot be found.

    NumPy 2.0 removed the top-level ``np.RankWarning`` alias in favour of
    ``np.exceptions.RankWarning``; earlier releases expose the top-level name.
    Both locations are probed so the script starts on either.
    """
    exceptions_module = getattr(np, 'exceptions', None)
    category = getattr(exceptions_module, 'RankWarning', None)
    if category is None:
        # On NumPy >= 2.0 this lookup raises AttributeError internally, which
        # `getattr` with a default turns into None.
        category = getattr(np, 'RankWarning', None)
    return category


def main() -> None:
    """Fine-tune DepthAnythingV2 for metric depth with DistributedDataParallel.

    Steps, per process:
      1. Parse CLI arguments and initialise the distributed environment.
      2. Build the train/val datasets and loaders for ``--dataset``.
      3. Build the model, optionally load weights whose key contains
         ``'pretrained'``, and wrap it in DDP.
      4. For each epoch: train with ``SiLogLoss``, evaluate on the validation
         loader, sum metrics onto rank 0, log them, and (rank 0 only) write
         ``latest.pth`` under ``--save-path``.

    Requires the ``LOCAL_RANK`` environment variable to be set (raises
    ``KeyError`` otherwise).
    """
    args = parser.parse_args()

    # Silence NumPy's RankWarning regardless of which NumPy major version is
    # installed (the original `np.RankWarning` lookup fails on NumPy >= 2.0).
    rank_warning = _rank_warning_category()
    if rank_warning is not None:
        warnings.simplefilter('ignore', rank_warning)

    logger = init_log('global', logging.INFO)
    logger.propagate = 0

    rank, world_size = setup_distributed(port=args.port)

    if rank == 0:
        all_args = {**vars(args), 'ngpus': world_size}
        logger.info('{}\n'.format(pprint.pformat(all_args)))
        # `writer` exists on rank 0 only; every later use is guarded by rank == 0.
        writer = SummaryWriter(args.save_path)

    cudnn.enabled = True
    cudnn.benchmark = True

    # ------------------------------------------------------------------ data
    size = (args.img_size, args.img_size)
    if args.dataset == 'hypersim':
        trainset = Hypersim('dataset/splits/hypersim/train.txt', 'train', size=size)
    elif args.dataset == 'vkitti':
        trainset = VKITTI2('dataset/splits/vkitti2/train.txt', 'train', size=size)
    else:
        raise NotImplementedError
    trainsampler = torch.utils.data.distributed.DistributedSampler(trainset)
    trainloader = DataLoader(trainset, batch_size=args.bs, pin_memory=True, num_workers=4,
                             drop_last=True, sampler=trainsampler)

    if args.dataset == 'hypersim':
        valset = Hypersim('dataset/splits/hypersim/val.txt', 'val', size=size)
    elif args.dataset == 'vkitti':
        # Training on VKITTI2 is validated on the KITTI split.
        valset = KITTI('dataset/splits/kitti/val.txt', 'val', size=size)
    else:
        raise NotImplementedError
    valsampler = torch.utils.data.distributed.DistributedSampler(valset)
    valloader = DataLoader(valset, batch_size=1, pin_memory=True, num_workers=4,
                           drop_last=True, sampler=valsampler)

    local_rank = int(os.environ["LOCAL_RANK"])

    # ----------------------------------------------------------------- model
    model_configs = {
        'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
        'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
        'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
        'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
    }
    model = DepthAnythingV2(**{**model_configs[args.encoder], 'max_depth': args.max_depth})

    if args.pretrained_from:
        # Only entries whose key contains 'pretrained' are loaded; strict=False
        # leaves every other parameter at its fresh initialisation.
        # NOTE(security): torch.load unpickles the file — only pass checkpoints
        # from a trusted source.
        model.load_state_dict(
            {k: v for k, v in torch.load(args.pretrained_from, map_location='cpu').items()
             if 'pretrained' in k},
            strict=False)

    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model.cuda(local_rank)
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[local_rank], broadcast_buffers=False,
        output_device=local_rank, find_unused_parameters=True)

    criterion = SiLogLoss().cuda(local_rank)

    # Two parameter groups: parameters named '*pretrained*' use the base LR,
    # everything else uses 10x the base LR.
    optimizer = AdamW(
        [{'params': [param for name, param in model.named_parameters() if 'pretrained' in name],
          'lr': args.lr},
         {'params': [param for name, param in model.named_parameters() if 'pretrained' not in name],
          'lr': args.lr * 10.0}],
        lr=args.lr, betas=(0.9, 0.999), weight_decay=0.01)

    total_iters = args.epochs * len(trainloader)

    # Best value seen so far per metric: d1/d2/d3 are maximised, the rest minimised.
    previous_best = {'d1': 0, 'd2': 0, 'd3': 0, 'abs_rel': 100, 'sq_rel': 100, 'rmse': 100,
                     'rmse_log': 100, 'log10': 100, 'silog': 100}

    for epoch in range(args.epochs):
        if rank == 0:
            logger.info('===========> Epoch: {:}/{:}, d1: {:.3f}, d2: {:.3f}, d3: {:.3f}'.format(
                epoch, args.epochs, previous_best['d1'], previous_best['d2'], previous_best['d3']))
            logger.info('===========> Epoch: {:}/{:}, abs_rel: {:.3f}, sq_rel: {:.3f}, rmse: {:.3f}, rmse_log: {:.3f}, '
                        'log10: {:.3f}, silog: {:.3f}'.format(
                            epoch, args.epochs, previous_best['abs_rel'], previous_best['sq_rel'],
                            previous_best['rmse'], previous_best['rmse_log'],
                            previous_best['log10'], previous_best['silog']))

        # Reseed the distributed sampler so each epoch sees a different shuffle.
        trainloader.sampler.set_epoch(epoch + 1)

        # ------------------------------------------------------------- train
        model.train()
        total_loss = 0

        for i, sample in enumerate(trainloader):
            optimizer.zero_grad()

            img, depth, valid_mask = sample['image'].cuda(), sample['depth'].cuda(), sample['valid_mask'].cuda()

            # With probability 0.5, flip image, depth and mask together along
            # the last axis (presumably width, i.e. a horizontal flip).
            if random.random() < 0.5:
                img = img.flip(-1)
                depth = depth.flip(-1)
                valid_mask = valid_mask.flip(-1)

            pred = model(img)

            # Supervise only pixels flagged valid and inside [min_depth, max_depth].
            loss = criterion(pred, depth,
                             (valid_mask == 1) & (depth >= args.min_depth) & (depth <= args.max_depth))

            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            iters = epoch * len(trainloader) + i

            # Polynomial decay (power 0.9) applied after the step, so it takes
            # effect from the next iteration; group 1 keeps its 10x multiplier.
            lr = args.lr * (1 - iters / total_iters) ** 0.9

            optimizer.param_groups[0]["lr"] = lr
            optimizer.param_groups[1]["lr"] = lr * 10.0

            if rank == 0:
                writer.add_scalar('train/loss', loss.item(), iters)

            if rank == 0 and i % 100 == 0:
                logger.info('Iter: {}/{}, LR: {:.7f}, Loss: {:.3f}'.format(
                    i, len(trainloader), optimizer.param_groups[0]['lr'], loss.item()))

        # -------------------------------------------------------------- eval
        model.eval()

        # Per-rank metric sums, kept as CUDA tensors so they can be reduced.
        # Key order matters: it drives the column order of the log table below.
        results = {
            metric_name: torch.tensor([0.0]).cuda()
            for metric_name in ('d1', 'd2', 'd3', 'abs_rel', 'sq_rel', 'rmse',
                                'rmse_log', 'log10', 'silog')
        }
        nsamples = torch.tensor([0.0]).cuda()

        for i, sample in enumerate(valloader):

            # Validation batch size is 1; `[0]` drops the batch axis of the targets.
            img, depth, valid_mask = sample['image'].cuda().float(), sample['depth'].cuda()[0], sample['valid_mask'].cuda()[0]

            with torch.no_grad():
                pred = model(img)
                # Resize the prediction to the ground-truth resolution.
                pred = F.interpolate(pred[:, None], depth.shape[-2:], mode='bilinear', align_corners=True)[0, 0]

            valid_mask = (valid_mask == 1) & (depth >= args.min_depth) & (depth <= args.max_depth)

            # Skip samples with fewer than 10 usable pixels.
            if valid_mask.sum() < 10:
                continue

            cur_results = eval_depth(pred[valid_mask], depth[valid_mask])

            for k in results.keys():
                results[k] += cur_results[k]
            nsamples += 1

        torch.distributed.barrier()

        # Sum metric totals and sample counts onto rank 0.
        for k in results.keys():
            dist.reduce(results[k], dst=0)
        dist.reduce(nsamples, dst=0)

        if rank == 0:
            logger.info('==========================================================================================')
            logger.info('{:>8}, {:>8}, {:>8}, {:>8}, {:>8}, {:>8}, {:>8}, {:>8}, {:>8}'.format(*tuple(results.keys())))
            logger.info('{:8.3f}, {:8.3f}, {:8.3f}, {:8.3f}, {:8.3f}, {:8.3f}, {:8.3f}, {:8.3f}, {:8.3f}'.format(
                *tuple([(v / nsamples).item() for v in results.values()])))
            logger.info('==========================================================================================')
            print()

            for name, metric in results.items():
                writer.add_scalar(f'eval/{name}', (metric / nsamples).item(), epoch)

        # NOTE(review): dist.reduce only guarantees the summed values on rank 0,
        # so on other ranks `previous_best` is not the global best. It is only
        # read on rank 0 (epoch log above, checkpoint below).
        for k in results.keys():
            if k in ['d1', 'd2', 'd3']:
                previous_best[k] = max(previous_best[k], (results[k] / nsamples).item())
            else:
                previous_best[k] = min(previous_best[k], (results[k] / nsamples).item())

        if rank == 0:
            checkpoint = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
                'previous_best': previous_best,
            }
            # Overwritten every epoch: only the most recent state is kept.
            torch.save(checkpoint, os.path.join(args.save_path, 'latest.pth'))


if __name__ == '__main__':
    main()