|
"""Compute depth maps for images in the input folder. |
|
""" |
|
import os |
|
import glob |
|
import torch |
|
import cv2 |
|
import argparse |
|
|
|
import util.io |
|
|
|
from torchvision.transforms import Compose |
|
|
|
from dpt.models import DPTDepthModel |
|
from dpt.midas_net import MidasNet_large |
|
from dpt.transforms import Resize, NormalizeImage, PrepareForNet |
|
|
|
|
|
|
|
|
|
def run(
    input_path,
    output_path,
    model_path,
    model_type="dpt_hybrid",
    optimize=True,
    kitti_crop=False,
    absolute_depth=False,
):
    """Run MonoDepthNN to compute depth maps.

    Every non-directory entry directly inside ``input_path`` is read with
    ``util.io.read_image``, passed through the network, upsampled back to the
    (possibly cropped) input resolution and written to ``output_path`` under
    the same base name via ``util.io.write_depth(..., bits=2)``.

    Args:
        input_path (str): path to input folder
        output_path (str): path to output folder (created if missing)
        model_path (str): path to saved model
        model_type (str): one of ``MODEL_TYPES``
        optimize (bool): when running on CUDA, convert model and inputs to
            ``channels_last`` memory format and half precision; ignored on CPU
        kitti_crop (bool): crop every image to its bottom-centred 352x1216
            window before inference
        absolute_depth (bool): forwarded to ``util.io.write_depth``

    Raises:
        ValueError: if ``model_type`` is unknown, or if ``kitti_crop`` is set
            and an image is smaller than 352x1216.
    """
    print("initialize")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("device: %s" % device)

    # Validates model_type before any filesystem side effects.
    model, net_w, net_h, normalization, depth_scale = _build_model(
        model_type, model_path
    )

    transform = Compose(
        [
            Resize(
                net_w,
                net_h,
                resize_target=None,
                keep_aspect_ratio=True,
                ensure_multiple_of=32,
                resize_method="minimal",
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            normalization,
            PrepareForNet(),
        ]
    )

    model.eval()

    # Half precision + channels_last is only enabled when running on a GPU.
    use_half = bool(optimize) and device.type == "cuda"
    if use_half:
        model = model.to(memory_format=torch.channels_last)
        model = model.half()

    model.to(device)

    # Drop sub-directories up front so the "(i/n)" progress counter only
    # counts images, and sort so the processing order is reproducible.
    img_names = sorted(
        name
        for name in glob.glob(os.path.join(input_path, "*"))
        if not os.path.isdir(name)
    )
    num_images = len(img_names)

    os.makedirs(output_path, exist_ok=True)

    print("start processing")
    for ind, img_name in enumerate(img_names):
        print("  processing {} ({}/{})".format(img_name, ind + 1, num_images))

        img = util.io.read_image(img_name)
        if kitti_crop:
            img = _kitti_crop(img)

        img_input = transform({"image": img})["image"]

        with torch.no_grad():
            sample = torch.from_numpy(img_input).to(device).unsqueeze(0)
            if use_half:
                sample = sample.to(memory_format=torch.channels_last)
                sample = sample.half()

            # Call the module rather than .forward() so registered hooks run.
            prediction = model(sample)

            # The model output is (N, H, W); add a channel axis for the 4-D
            # bicubic interpolate and upsample to the input size. Cast to
            # float32 first so the depth scaling below is not done in
            # float16, and index [0, 0] instead of squeeze() so a 1-pixel
            # high/wide image still yields a 2-D map.
            prediction = (
                torch.nn.functional.interpolate(
                    prediction.float().unsqueeze(1),
                    size=img.shape[:2],
                    mode="bicubic",
                    align_corners=False,
                )[0, 0]
                .cpu()
                .numpy()
            )

        if depth_scale != 1.0:
            prediction *= depth_scale

        filename = os.path.join(
            output_path, os.path.splitext(os.path.basename(img_name))[0]
        )
        util.io.write_depth(
            filename, prediction, bits=2, absolute_depth=absolute_depth
        )

    print("finished")


# Per-variant DPT settings: (net_w, net_h, depth_scale, extra DPTDepthModel
# kwargs). depth_scale multiplies the upsampled prediction before it is
# written; presumably it maps metric depth to each dataset's integer PNG
# encoding (KITTI x256, NYU x1000) -- TODO confirm against util.io.write_depth.
_DPT_CONFIGS = {
    "dpt_large": (384, 384, 1.0, {"backbone": "vitl16_384"}),
    "dpt_hybrid": (384, 384, 1.0, {"backbone": "vitb_rn50_384"}),
    "dpt_hybrid_kitti": (
        1216,
        352,
        256.0,
        {
            "backbone": "vitb_rn50_384",
            "scale": 0.00006016,
            "shift": 0.00579,
            "invert": True,
        },
    ),
    "dpt_hybrid_nyu": (
        640,
        480,
        1000.0,
        {
            "backbone": "vitb_rn50_384",
            "scale": 0.000305,
            "shift": 0.1378,
            "invert": True,
        },
    ),
}

# Every model_type accepted by run() and by --model_type.
MODEL_TYPES = tuple(_DPT_CONFIGS) + ("midas_v21",)

# Size of the window cut out by --kitti_crop.
_KITTI_CROP_H, _KITTI_CROP_W = 352, 1216


def _build_model(model_type, model_path):
    """Instantiate the network for ``model_type`` plus its preprocessing settings.

    Returns:
        tuple: ``(model, net_w, net_h, normalization, depth_scale)``

    Raises:
        ValueError: if ``model_type`` is not in ``MODEL_TYPES`` (an explicit
            raise, unlike ``assert``, survives ``python -O``).
    """
    if model_type == "midas_v21":
        model = MidasNet_large(model_path, non_negative=True)
        normalization = NormalizeImage(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        return model, 384, 384, normalization, 1.0

    if model_type not in _DPT_CONFIGS:
        raise ValueError(
            f"model_type '{model_type}' not implemented, use: --model_type "
            f"[{'|'.join(MODEL_TYPES)}]"
        )

    net_w, net_h, depth_scale, extra_kwargs = _DPT_CONFIGS[model_type]
    model = DPTDepthModel(
        path=model_path,
        non_negative=True,
        enable_attention_hooks=False,
        **extra_kwargs,
    )
    # All DPT variants share the same normalization.
    normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    return model, net_w, net_h, normalization, depth_scale


def _kitti_crop(img):
    """Return the bottom-centred 352x1216 window of an H x W (x C) image.

    Raises:
        ValueError: if the image is smaller than the window; plain slicing
            would otherwise silently return a wrongly sized crop.
    """
    height, width = img.shape[:2]
    if height < _KITTI_CROP_H or width < _KITTI_CROP_W:
        raise ValueError(
            f"--kitti_crop needs images of at least {_KITTI_CROP_H}x{_KITTI_CROP_W}"
            f" pixels, got {height}x{width}"
        )
    top = height - _KITTI_CROP_H
    left = (width - _KITTI_CROP_W) // 2
    return img[top : top + _KITTI_CROP_H, left : left + _KITTI_CROP_W]


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-i", "--input_path", default="input", help="folder with input images"
    )
    parser.add_argument(
        "-o",
        "--output_path",
        default="output_monodepth",
        help="folder for output images",
    )
    parser.add_argument(
        "-m", "--model_weights", default=None, help="path to model weights"
    )
    # choices= rejects unknown types up front instead of failing with a
    # KeyError in the default_models lookup below.
    parser.add_argument(
        "-t",
        "--model_type",
        default="dpt_hybrid",
        choices=MODEL_TYPES,
        help="model type [{}]".format("|".join(MODEL_TYPES)),
    )

    parser.add_argument("--kitti_crop", dest="kitti_crop", action="store_true")
    parser.add_argument(
        "--absolute_depth", dest="absolute_depth", action="store_true"
    )

    parser.add_argument("--optimize", dest="optimize", action="store_true")
    parser.add_argument("--no-optimize", dest="optimize", action="store_false")

    parser.set_defaults(optimize=True)
    parser.set_defaults(kitti_crop=False)
    parser.set_defaults(absolute_depth=False)

    args = parser.parse_args()

    default_models = {
        "midas_v21": "weights/midas_v21-f6b98070.pt",
        "dpt_large": "weights/dpt_large-midas-2f21e586.pt",
        "dpt_hybrid": "weights/dpt_hybrid-midas-501f0c75.pt",
        "dpt_hybrid_kitti": "weights/dpt_hybrid_kitti-cb926ef4.pt",
        "dpt_hybrid_nyu": "weights/dpt_hybrid_nyu-2ce69ec7.pt",
    }

    if args.model_weights is None:
        args.model_weights = default_models[args.model_type]

    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    run(
        args.input_path,
        args.output_path,
        args.model_weights,
        args.model_type,
        args.optimize,
        kitti_crop=args.kitti_crop,
        absolute_depth=args.absolute_depth,
    )
|
|