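"""Training launcher for ZoeDepth monocular depth models.

Parses the model/dataset/trainer CLI flags, builds the experiment config,
and runs single-GPU or distributed (SLURM-aware) training.
"""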
import argparse
import os
from pprint import pprint

import numpy as np
import torch
import torch.multiprocessing as mp
import torch.utils.data.distributed

from zoedepth.data.data_mono import DepthDataLoader
from zoedepth.models.builder import build_model
from zoedepth.trainers.builder import get_trainer
from zoedepth.utils.arg_utils import parse_unknown
from zoedepth.utils.config import get_config
from zoedepth.utils.misc import count_parameters, parallelize

# Render off-screen via EGL and start wandb in a thread so it cooperates
# with torch.multiprocessing workers.
os.environ["PYOPENGL_PLATFORM"] = "egl"
os.environ["WANDB_START_METHOD"] = "thread"
|
def fix_random_seed(seed: int): |
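    """Seed all RNGs (Python, NumPy, PyTorch CPU/CUDA) for reproducibility."""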
    import random

    import numpy
    import torch

    random.seed(seed)
    numpy.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # benchmark=True lets cuDNN autotune onto non-deterministic kernel
    # choices, which contradicts the deterministic flag above, so it is
    # disabled here.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
|
def load_ckpt(config, model, checkpoint_dir="./checkpoints", ckpt_type="best"): |
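    """Load weights into ``model`` from a checkpoint resolved via the config.

    Uses ``config.checkpoint`` if present; otherwise globs ``checkpoint_dir``
    for files matching ``config.ckpt_pattern`` and ``ckpt_type``. Returns the
    model unchanged if neither attribute is set.
    """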
    import glob
    import os

    from zoedepth.models.model_io import load_wts

    if hasattr(config, "checkpoint"):
        checkpoint = config.checkpoint
    elif hasattr(config, "ckpt_pattern"):
        pattern = config.ckpt_pattern
        matches = glob.glob(os.path.join(
            checkpoint_dir, f"*{pattern}*{ckpt_type}*"))
        if not matches:
            raise ValueError(f"No matches found for the pattern {pattern}")
        checkpoint = matches[0]
    else:
        return model

    model = load_wts(model, checkpoint)
    print(f"Loaded weights from {checkpoint}")
    return model
|
|
def main_worker(gpu, ngpus_per_node, config): |
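    """Per-process training entry point: build, restore, parallelize, train."""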
    try:
        # Fall back to a fixed seed of 43 when the config does not set one.
        seed = config.seed if 'seed' in config and config.seed else 43
        fix_random_seed(seed)

        config.gpu = gpu

        model = build_model(config)
        model = load_ckpt(config, model)
        model = parallelize(config, model)

        total_params = f"{round(count_parameters(model)/1e6, 2)}M"
        config.total_params = total_params
        print(f"Total parameters : {total_params}")

        train_loader = DepthDataLoader(config, "train").data
        test_loader = DepthDataLoader(config, "online_eval").data

        trainer = get_trainer(config)(
            config, model, train_loader, test_loader, device=config.gpu)

        trainer.train()
    finally:
        # Always close the wandb run, even when training raises.
        import wandb
        wandb.finish()
|
|
if __name__ == '__main__': |
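    # 'forkserver' starts workers from a fresh interpreter, avoiding
    # inherited CUDA and wandb state in spawned processes.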
    mp.set_start_method('forkserver')

    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", type=str, default="synunet")
    parser.add_argument("-d", "--dataset", type=str, default='nyu')
    parser.add_argument("--trainer", type=str, default=None)

    args, unknown_args = parser.parse_known_args()
    # Extra --key value pairs on the command line override config fields.
    overwrite_kwargs = parse_unknown(unknown_args)

    overwrite_kwargs["model"] = args.model
    if args.trainer is not None:
        overwrite_kwargs["trainer"] = args.trainer

    config = get_config(args.model, "train", args.dataset, **overwrite_kwargs)

    # Optional dict shared across worker processes (e.g. for caching).
    if config.use_shared_dict:
        shared_dict = mp.Manager().dict()
    else:
        shared_dict = None
    config.shared_dict = shared_dict

    config.batch_size = config.bs
    config.mode = 'train'
    if config.root != "." and not os.path.isdir(config.root):
        os.makedirs(config.root)

    # Derive the node list and rank from SLURM when running under it;
    # otherwise assume a single local node.
    try:
        node_str = os.environ['SLURM_JOB_NODELIST'].replace(
            '[', '').replace(']', '')
        nodes = node_str.split(',')

        config.world_size = len(nodes)
        config.rank = int(os.environ['SLURM_PROCID'])
    except KeyError:
        config.world_size = 1
        config.rank = 0
        nodes = ["127.0.0.1"]

    if config.distributed:
        print(config.rank)
        # Random rendezvous port to avoid collisions between concurrent runs.
        port = np.random.randint(15000, 15025)
        config.dist_url = 'tcp://{}:{}'.format(nodes[0], port)
        print(config.dist_url)
        config.dist_backend = 'nccl'
        config.gpu = None

    ngpus_per_node = torch.cuda.device_count()
    config.num_workers = config.workers
    config.ngpus_per_node = ngpus_per_node
    print("Config:")
    pprint(config)

    if config.distributed:
        # world_size becomes the total number of processes (one per GPU).
        config.world_size = ngpus_per_node * config.world_size
        mp.spawn(main_worker, nprocs=ngpus_per_node,
                 args=(ngpus_per_node, config))
    else:
        if ngpus_per_node == 1:
            config.gpu = 0
        main_worker(config.gpu, ngpus_per_node, config)