import json |
import os |
from zoedepth.utils.easydict import EasyDict as edict |
from zoedepth.utils.arg_utils import infer_type |
import pathlib |
import platform |
# Absolute path of the parent of the directory that contains this file.
# get_model_config() looks for model JSON files under ROOT/models/<model_name>/.
ROOT = pathlib.Path(__file__).parent.parent.resolve()

# Base directory that every dataset path below is joined onto.
HOME_DIR = os.path.expanduser("./data")

# Defaults for all models (referred to as COMMON_CONFIG by get_config's
# documentation). The assignment header was missing from the source as
# received and has been restored here.
COMMON_CONFIG = {
    "save_dir": os.path.expanduser("./depth_anything_finetune"),
    "project": "ZoeDepth",
    "tags": '',
    "notes": "",
    "gpu": None,
    "root": ".",
    "uid": None,
    "print_losses": False
}

# Per-dataset settings keyed by dataset name. get_config() layers the selected
# entry underneath the resolved config; change_dataset() writes an entry over
# an existing config. The assignment header was missing from the source as
# received and has been restored here.
DATASETS_CONFIG = {
    "kitti": {
        "dataset": "kitti",
        "min_depth": 0.001,
        "max_depth": 80,
        "data_path": os.path.join(HOME_DIR, "Kitti/raw_data"),
        "gt_path": os.path.join(HOME_DIR, "Kitti/data_depth_annotated_zoedepth"),
        "filenames_file": "./train_test_inputs/kitti_eigen_train_files_with_gt.txt",
        "input_height": 352,
        "input_width": 1216,
        "data_path_eval": os.path.join(HOME_DIR, "Kitti/raw_data"),
        "gt_path_eval": os.path.join(HOME_DIR, "Kitti/data_depth_annotated_zoedepth"),
        "filenames_file_eval": "./train_test_inputs/kitti_eigen_test_files_with_gt.txt",

        "min_depth_eval": 1e-3,
        "max_depth_eval": 80,

        "do_random_rotate": True,
        "degree": 1.0,
        "do_kb_crop": True,
        "garg_crop": True,
        "eigen_crop": False,
        "use_right": False
    },
    # Same as "kitti" except that random rotation is switched off.
    "kitti_test": {
        "dataset": "kitti",
        "min_depth": 0.001,
        "max_depth": 80,
        "data_path": os.path.join(HOME_DIR, "Kitti/raw_data"),
        "gt_path": os.path.join(HOME_DIR, "Kitti/data_depth_annotated_zoedepth"),
        "filenames_file": "./train_test_inputs/kitti_eigen_train_files_with_gt.txt",
        "input_height": 352,
        "input_width": 1216,
        "data_path_eval": os.path.join(HOME_DIR, "Kitti/raw_data"),
        "gt_path_eval": os.path.join(HOME_DIR, "Kitti/data_depth_annotated_zoedepth"),
        "filenames_file_eval": "./train_test_inputs/kitti_eigen_test_files_with_gt.txt",

        "min_depth_eval": 1e-3,
        "max_depth_eval": 80,

        "do_random_rotate": False,
        "degree": 1.0,
        "do_kb_crop": True,
        "garg_crop": True,
        "eigen_crop": False,
        "use_right": False
    },
    "nyu": {
        "dataset": "nyu",
        "avoid_boundary": False,
        "min_depth": 1e-3,
        "max_depth": 10,
        "data_path": os.path.join(HOME_DIR, "nyu"),
        "gt_path": os.path.join(HOME_DIR, "nyu"),
        "filenames_file": "./train_test_inputs/nyudepthv2_train_files_with_gt.txt",
        "input_height": 480,
        "input_width": 640,
        "data_path_eval": os.path.join(HOME_DIR, "nyu"),
        "gt_path_eval": os.path.join(HOME_DIR, "nyu"),
        "filenames_file_eval": "./train_test_inputs/nyudepthv2_test_files_with_gt.txt",
        "min_depth_eval": 1e-3,
        "max_depth_eval": 10,
        "min_depth_diff": -10,
        "max_depth_diff": 10,

        "do_random_rotate": True,
        "degree": 1.0,
        "do_kb_crop": False,
        "garg_crop": False,
        "eigen_crop": True
    },
    "ibims": {
        "dataset": "ibims",
        "ibims_root": os.path.join(HOME_DIR, "iBims1/m1455541/ibims1_core_raw/"),
        "eigen_crop": True,
        "garg_crop": False,
        "do_kb_crop": False,
        "min_depth_eval": 0,
        "max_depth_eval": 10,
        "min_depth": 1e-3,
        "max_depth": 10
    },
    "sunrgbd": {
        "dataset": "sunrgbd",
        "sunrgbd_root": os.path.join(HOME_DIR, "SUNRGB-D"),
        "eigen_crop": True,
        "garg_crop": False,
        "do_kb_crop": False,
        "min_depth_eval": 0,
        "max_depth_eval": 8,
        "min_depth": 1e-3,
        "max_depth": 10
    },
    "diml_indoor": {
        "dataset": "diml_indoor",
        "diml_indoor_root": os.path.join(HOME_DIR, "DIML/indoor/sample/testset/"),
        "eigen_crop": True,
        "garg_crop": False,
        "do_kb_crop": False,
        "min_depth_eval": 0,
        "max_depth_eval": 10,
        "min_depth": 1e-3,
        "max_depth": 10
    },
    "diml_outdoor": {
        "dataset": "diml_outdoor",
        "diml_outdoor_root": os.path.join(HOME_DIR, "DIML/outdoor/test/LR"),
        "eigen_crop": False,
        "garg_crop": True,
        "do_kb_crop": False,
        "min_depth_eval": 2,
        "max_depth_eval": 80,
        "min_depth": 1e-3,
        "max_depth": 80
    },
    "diode_indoor": {
        "dataset": "diode_indoor",
        "diode_indoor_root": os.path.join(HOME_DIR, "DIODE/val/indoors/"),
        "eigen_crop": True,
        "garg_crop": False,
        "do_kb_crop": False,
        "min_depth_eval": 1e-3,
        "max_depth_eval": 10,
        "min_depth": 1e-3,
        "max_depth": 10
    },
    "diode_outdoor": {
        "dataset": "diode_outdoor",
        "diode_outdoor_root": os.path.join(HOME_DIR, "DIODE/val/outdoor/"),
        "eigen_crop": False,
        "garg_crop": True,
        "do_kb_crop": False,
        "min_depth_eval": 1e-3,
        "max_depth_eval": 80,
        "min_depth": 1e-3,
        "max_depth": 80
    },
    "hypersim_test": {
        "dataset": "hypersim_test",
        "hypersim_test_root": os.path.join(HOME_DIR, "HyperSim/"),
        "eigen_crop": True,
        "garg_crop": False,
        "do_kb_crop": False,
        "min_depth_eval": 1e-3,
        "max_depth_eval": 80,
        "min_depth": 1e-3,
        "max_depth": 10
    },
    "vkitti": {
        "dataset": "vkitti",
        "vkitti_root": os.path.join(HOME_DIR, "shortcuts/datasets/vkitti_test/"),
        "eigen_crop": False,
        "garg_crop": True,
        "do_kb_crop": True,
        "min_depth_eval": 1e-3,
        "max_depth_eval": 80,
        "min_depth": 1e-3,
        "max_depth": 80
    },
    "vkitti2": {
        "dataset": "vkitti2",
        "vkitti2_root": os.path.join(HOME_DIR, "vKitti2/"),
        "eigen_crop": False,
        "garg_crop": True,
        "do_kb_crop": True,
        "min_depth_eval": 1e-3,
        "max_depth_eval": 80,
        "min_depth": 1e-3,
        "max_depth": 80,
    },
    "ddad": {
        "dataset": "ddad",
        "ddad_root": os.path.join(HOME_DIR, "shortcuts/datasets/ddad/ddad_val/"),
        "eigen_crop": False,
        "garg_crop": True,
        "do_kb_crop": True,
        "min_depth_eval": 1e-3,
        "max_depth_eval": 80,
        "min_depth": 1e-3,
        "max_depth": 80,
    },
}
# Named groups of dataset keys; every entry is a key of the per-dataset table.
ALL_INDOOR = ["nyu", "ibims", "sunrgbd", "diode_indoor", "hypersim_test"]
ALL_OUTDOOR = ["kitti", "diml_outdoor", "diode_outdoor", "vkitti2", "ddad"]

# Training-related defaults.
# NOTE(review): the assignment header of this dict was missing from the source
# as received; the name COMMON_TRAINING_CONFIG is assumed - confirm it against
# the code that consumes these defaults.
COMMON_TRAINING_CONFIG = {
    "dataset": "nyu",
    "distributed": True,
    "workers": 16,
    "clip_grad": 0.1,
    "use_shared_dict": False,
    "shared_dict": None,
    "use_amp": False,

    "aug": True,
    "random_crop": False,
    "random_translate": False,
    "translate_prob": 0.2,
    "max_translation": 100,

    "validate_every": 0.25,
    "log_images_every": 0.1,
    "prefetch": False,
}
def flatten(config, except_keys=('bin_conf',)):
    """Collapse a nested dict into a single-level dict of its leaf entries.

    Nested dicts are descended into and their non-dict entries are lifted to
    the top level; when the same key occurs more than once, the entry met
    last wins.

    Args:
        config (dict): Possibly nested mapping. Anything that is not a dict
            produces an empty result.
        except_keys (tuple or str): Keys whose value is also kept verbatim
            under the key itself. A single key may be given as a plain
            string. Note that a dict stored under such a key is still
            descended into as well, so its entries appear both nested and
            at the top level.

    Returns:
        dict: The flattened mapping.
    """
    # A bare string would turn `key in except_keys` into a substring test
    # (e.g. "conf" in "bin_conf") and raise TypeError for non-string keys.
    if isinstance(except_keys, str):
        except_keys = (except_keys,)

    def recurse(inp):
        if isinstance(inp, dict):
            for key, value in inp.items():
                if key in except_keys:
                    yield (key, value)
                if isinstance(value, dict):
                    yield from recurse(value)
                else:
                    yield (key, value)

    return dict(recurse(config))
def split_combined_args(kwargs):
    """Expand '__'-combined arguments into individual key-value pairs.

    A combined argument packs several keys into one name and several values
    into one string: the name starts with '__' and joins the keys with '__',
    the value joins the values with ';'. For example the pair
    ('__n_bins__lr', '256;0.001') expands to n_bins='256' and lr='0.001'.
    The expanded values stay strings.

    Args:
        kwargs (dict): Arguments, some of which may be combined as above.

    Returns:
        dict: A copy of kwargs with the expanded pairs added. The combined
            entry itself is kept as well.

    Raises:
        AssertionError: If a combined argument has a different number of
            keys and values.
    """
    expanded = dict(kwargs)
    for name, packed in kwargs.items():
        if not name.startswith("__"):
            continue
        # The first piece is the empty string in front of the leading '__'.
        names = name.split("__")[1:]
        parts = packed.split(";")
        assert len(names) == len(
            parts), f"Combined arguments should have equal number of keys and values. Keys are separated by '__' and Values are separated with ';'. For example, '__n_bins__lr=256;0.001. Given (keys,values) is ({names}, {parts})"
        expanded.update(zip(names, parts))
    return expanded
def parse_list(config, key, dtype=int):
    """Parse config[key] into a list of dtype values, modifying config in place.

    If the value is a string it is split on ',' and every piece is converted
    with dtype. The resulting (or already present) value must then be a list
    whose elements are all instances of dtype. Nothing happens when key is
    absent from config.

    Args:
        config (dict): Configuration to update in place.
        key: Key whose value should be parsed.
        dtype (type): Element type used both for conversion and validation.
            Defaults to int.

    Raises:
        AssertionError: If the value is not a list of dtype instances.
    """
    if key not in config:
        return

    value = config[key]
    if isinstance(value, str):
        value = [dtype(item) for item in value.split(',')]
        config[key] = value

    if isinstance(value, list) and all(isinstance(e, dtype) for e in value):
        return

    # Build the message defensively: iterating a non-iterable value (e.g. a
    # bare int) would otherwise raise TypeError and hide the real problem.
    try:
        element_types = [type(e) for e in value]
    except TypeError:
        element_types = "<not iterable>"
    # Raised explicitly (rather than via `assert`) so the check also runs
    # under `python -O`; the exception type callers see is unchanged.
    raise AssertionError(
        f"{key} should be a list of values dtype {dtype}. Given {value} of type {type(value)} with values of type {element_types}.")
def get_model_config(model_name, model_version=None):
    """Find and parse the .json config file for the model.

    The file is looked up as ROOT/models/{model_name}/config_{model_name}.json,
    or config_{model_name}_{model_version}.json when a version is given. If
    the "train" section names another model under "inherit", every key of
    that model's "train" section that is missing here is copied in.

    Args:
        model_name (str): Name of the model.
        model_version (str, optional): Specific config version. Defaults to
            None, which selects the unversioned file.

    Returns:
        easydict: The config dictionary for the model, or None when the
            config file does not exist.

    Raises:
        ValueError: If the config names an "inherit" model whose config
            file cannot be found.
    """
    if model_version is not None:
        config_fname = f"config_{model_name}_{model_version}.json"
    else:
        config_fname = f"config_{model_name}.json"
    config_file = os.path.join(ROOT, "models", model_name, config_fname)
    if not os.path.exists(config_file):
        return None

    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(config_file, "r", encoding="utf-8") as f:
        config = edict(json.load(f))

    parent_name = config.train.get("inherit")
    if parent_name is not None:
        parent = get_model_config(parent_name)
        if parent is None:
            # Without this check the failure is an opaque AttributeError on None.
            raise ValueError(
                f"Config file for model {parent_name} (inherited by {config_file}) not found.")
        for key, value in parent.train.items():
            if key not in config.train:
                # Plain item assignment keeps edict's attribute view in sync.
                config.train[key] = value
    return edict(config)
def update_model_config(config, mode, model_name, model_version=None, strict=False):
    """Return config overlaid with the model's flattened JSON settings.

    The model config's "model" section and its section for the given mode
    are merged (the mode section wins on clashes), flattened, and laid over
    a copy of config. The input config is not modified.

    Args:
        config (dict): Current flat configuration.
        mode (str): Name of the mode section to read from the model config.
        model_name (str): Name of the model whose config file is loaded.
        model_version (str, optional): Config version to load. Defaults to None.
        strict (bool, optional): Raise instead of returning config unchanged
            when no config file is found. Defaults to False.

    Returns:
        dict: The updated configuration, or config itself when no config
            file was found and strict is False.

    Raises:
        ValueError: If no config file is found and strict is True.
    """
    loaded = get_model_config(model_name, model_version)
    if loaded is None:
        if strict:
            raise ValueError(f"Config file for model {model_name} not found.")
        return config

    sections = {**loaded.model, **loaded[mode]}
    overrides = flatten(sections)
    return {**config, **overrides}
def check_choices(name, value, choices):
    """Validate that value is one of choices.

    Args:
        name (str): Label for the value, used in the error message.
        value: Value to check.
        choices: Container of accepted values.

    Raises:
        ValueError: If value is not in choices.
    """
    if value in choices:
        return
    raise ValueError(f"{name} {value} not in supported choices {choices}")
# Config keys whose values get_config() coerces to bool when present.
KEYS_TYPE_BOOL = [
    "use_amp",
    "distributed",
    "use_shared_dict",
    "same_lr",
    "aug",
    "three_phase",
    "prefetch",
    "cycle_momentum",
]
# Lower-cased, stripped strings that mean "off" for a boolean flag.
_FALSE_STRINGS = frozenset({"false", "0", "no", "off", "n", "f", ""})


def _as_bool(value):
    """Coerce a flag value to bool, reading textual negatives such as "False" as False."""
    if isinstance(value, str):
        return value.strip().lower() not in _FALSE_STRINGS
    return bool(value)


def get_config(model_name, mode='train', dataset=None, **overwrite_kwargs):
    """Main entry point to get the config for the model.

    Args:
        model_name (str): Name of the desired model ("zoedepth" or "zoedepth_nk").
        mode (str, optional): "train", "infer" or "eval". Defaults to 'train'.
        dataset (str, optional): If specified, the corresponding dataset
            configuration is loaded as well. In "train" mode it must be
            "nyu", "kitti", "mix" or None; "mix" loads the "nyu" settings.
            Defaults to None.

    Keyword Args: key-value pairs of arguments to overwrite the default
        config. Arguments combined with '__' (see split_combined_args) are
        expanded first.

    The order of precedence for overwriting the config is (higher precedence first):
        1. overwrite_kwargs
        2. "config_version": config_{model_name}_{config_version}.json, if
           given in overwrite_kwargs
        3. "version_name": config_{model_name}_{version_name}.json, taken
           from overwrite_kwargs or else from the config loaded so far
        4. config_{model_name}.json
        5. the common module-level defaults
        6. the dataset configuration (only fills keys not set above)
    The "model" key, the "dataset" key (when a dataset is given) and, in
    "train" mode with a dataset, the "project" key are set by this function
    and override all of the above.

    Keys listed in KEYS_TYPE_BOOL are coerced to bool (strings such as
    "False" or "0" count as False), "n_attractors" is parsed into a list of
    ints, every value is passed through infer_type, and the machine's
    hostname is stored under "hostname".

    Returns:
        easydict: The config dictionary for the model.

    Raises:
        ValueError: If model_name, mode or (in "train" mode) dataset is not
            a supported choice.
    """
    check_choices("Model", model_name, ["zoedepth", "zoedepth_nk"])
    check_choices("Mode", mode, ["train", "infer", "eval"])
    if mode == "train":
        check_choices("Dataset", dataset, ["nyu", "kitti", "mix", None])

    # NOTE(review): the source as received read `config` before ever assigning
    # it. It is seeded here from the module-level default dicts - confirm
    # that these are the intended names.
    config = flatten({**COMMON_CONFIG, **COMMON_TRAINING_CONFIG})
    config = update_model_config(config, mode, model_name)

    # Layer the version-specific model config on top of the base one.
    version_name = overwrite_kwargs.get("version_name", config["version_name"])
    config = update_model_config(config, mode, model_name, version_name)

    # An explicitly requested config version outranks version_name.
    config_version = overwrite_kwargs.get("config_version", None)
    if config_version is not None:
        print("Overwriting config with config_version", config_version)
        config = update_model_config(config, mode, model_name, config_version)

    # Caller-supplied overrides win over everything loaded from files.
    overwrite_kwargs = split_combined_args(overwrite_kwargs)
    config = {**config, **overwrite_kwargs}

    # bool("False") is True, so textual flags need explicit parsing.
    for key in KEYS_TYPE_BOOL:
        if key in config:
            config[key] = _as_bool(config[key])

    parse_list(config, "n_attractors")

    # An n_bins override also has to reach every entry of bin_conf.
    if 'bin_conf' in config and 'n_bins' in overwrite_kwargs:
        bin_conf = config['bin_conf']
        n_bins = overwrite_kwargs['n_bins']
        new_bin_conf = []
        for conf in bin_conf:
            conf['n_bins'] = n_bins
            new_bin_conf.append(conf)
        config['bin_conf'] = new_bin_conf

    if mode == "train":
        orig_dataset = dataset
        if dataset == "mix":
            # "mix" has no entry of its own; the "nyu" settings are used.
            dataset = 'nyu'
        if dataset is not None:
            config['project'] = f"MonoDepth3-{orig_dataset}"

    if dataset is not None:
        config['dataset'] = dataset
        # Dataset settings only fill in keys that are not set yet.
        config = {**DATASETS_CONFIG[dataset], **config}

    config['model'] = model_name
    typed_config = {k: infer_type(v) for k, v in config.items()}
    # Added after type inference: previously the hostname was written to
    # `config` only after `typed_config` had been built, so it never reached
    # the returned config.
    typed_config['hostname'] = platform.node()
    return edict(typed_config)
def change_dataset(config, new_dataset):
    """Overlay the settings of another dataset onto config, in place.

    Keys already in config that the new dataset's entry does not define are
    left untouched.

    Args:
        config (dict): Configuration to update.
        new_dataset (str): Key into DATASETS_CONFIG.

    Returns:
        dict: The same config object, after the update.

    Raises:
        KeyError: If new_dataset has no entry in DATASETS_CONFIG.
    """
    dataset_settings = DATASETS_CONFIG[new_dataset]
    config.update(dataset_settings)
    return config