# MIT License

# Copyright (c) 2022 Intelligent Systems Lab Org

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# File author: Zhenyu Li

import os
import cv2
import argparse
from zoedepth.utils.config import get_config_user
from zoedepth.models.builder import build_model
from zoedepth.utils.arg_utils import parse_unknown
import numpy as np
from zoedepth.models.base_models.midas import Resize
from torchvision.transforms import Compose
from torchvision.transforms import Normalize
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import copy
from zoedepth.utils.misc import get_boundaries
from zoedepth.utils.misc import compute_metrics, RunningAverageDict
from tqdm import tqdm
import matplotlib
import torch.nn.functional as F
from zoedepth.data.middleburry import readPFM
import random
import imageio
from PIL import Image


def load_state_dict(model, state_dict):
    """Load state_dict into model, handling DataParallel and DistributedDataParallel.

    Also checks for a "model" key in state_dict and unwraps it when present.

    DataParallel prefixes state_dict keys with 'module.' when saving.
    If the model is not a DataParallel model but the state_dict is, the
    prefixes are removed. If the model is a DataParallel model but the
    state_dict is not, the prefixes are added.

    Args:
        model: Module that receives the weights.
        state_dict: Mapping of parameter names to tensors, optionally nested
            under a "model" key.

    Returns:
        The same ``model`` object, with weights loaded (strict matching).
    """
    state_dict = state_dict.get('model', state_dict)

    # A wrapped model expects every key to carry the 'module.' prefix.
    do_prefix = isinstance(
        model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel))

    state = {}
    for key, value in state_dict.items():
        if key.startswith('module.') and not do_prefix:
            key = key[7:]  # len('module.') == 7
        if not key.startswith('module.') and do_prefix:
            key = 'module.' + key
        state[key] = value

    model.load_state_dict(state, strict=True)
    print("Loaded successfully")
    return model


def load_wts(model, checkpoint_path):
    """Read a checkpoint file onto the CPU and load it into ``model``."""
    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    ckpt = torch.load(checkpoint_path, map_location='cpu')
    return load_state_dict(model, ckpt)


def load_ckpt(model, checkpoint):
    """Load weights from ``checkpoint`` into ``model`` and report the path used."""
    model = load_wts(model, checkpoint)
    print("Loaded weights from {0}".format(checkpoint))
    return model


#### def dataset
def _read_rgb_unit_range(path):
    """Read an image with OpenCV and return it as RGB scaled by 1/255."""
    img = cv2.imread(path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise OSError("cv2.imread could not read image: {0}".format(path))
    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0


def _resize_to_img_resolution(img):
    """Bicubically resize an HxWxC array to the global IMG_RESOLUTION.

    Returns a torch tensor laid out as (H, W, C).
    """
    batch = torch.tensor(img).unsqueeze(dim=0).permute(0, 3, 1, 2)
    batch = F.interpolate(batch, IMG_RESOLUTION, mode='bicubic', align_corners=True)
    return batch.squeeze().permute(1, 2, 0)


def read_image(path, dataset_name):
    """Read one RGB image, using the loader that matches ``dataset_name``.

    Args:
        path: Path of the image file.
        dataset_name: 'u4k' reads a raw uint8 dump reshaped to 2160x3840x3 and
            reverses its channel order; 'nyu' reads with PIL; every other
            value (including 'mid') reads with OpenCV and resizes to the
            module-level IMG_RESOLUTION.

    Returns:
        The image scaled by 1/255. The 'u4k' and 'nyu' branches return a
        float32 numpy array; the remaining branches return a torch tensor.

    Raises:
        OSError: If OpenCV cannot read the file.
    """
    if dataset_name == 'u4k':
        # np.fromfile accepts a path, so no file handle is left open.
        raw = np.fromfile(path, dtype=np.uint8).reshape(2160, 3840, 3) / 255.0
        return raw.astype(np.float32)[:, :, ::-1].copy()

    if dataset_name == 'nyu':
        with Image.open(path) as pil_img:
            return np.asarray(pil_img, dtype=np.float32) / 255.0

    # 'mid' and any other dataset share the OpenCV + resize path.
    return _resize_to_img_resolution(_read_rgb_unit_range(path))


class Images:
    """One RGB input image together with its extension-free name."""

    def __init__(self, root_dir, files, index, dataset_name=None):
        """Load ``files[index]`` from ``root_dir``.

        Args:
            root_dir: Directory that contains the image files.
            files: Sequence of file names inside ``root_dir``.
            index: Position of the wanted file in ``files``.
            dataset_name: Forwarded to ``read_image`` to pick the loader.
        """
        self.root_dir = root_dir
        self.dataset_name = dataset_name
        name = files[index]
        self.rgb_image = read_image(os.path.join(self.root_dir, name), dataset_name)
        # Strip known image extensions (wherever they occur in the name).
        for extension in (".jpg", ".png", ".jpeg"):
            name = name.replace(extension, "")
        self.name = name


class DepthMap:
    """Ground-truth (or saved-prediction) depth plus its edge map."""

    def __init__(self, root_dir, files, index, dataset_name, pred=False):
        """Load the depth map for ``files[index]``.

        Args:
            root_dir: Directory that contains the depth files.
            files: Sequence of file names inside ``root_dir``.
            index: Position of the wanted file in ``files``.
            dataset_name: One of 'u4k', 'gta', 'mid' or 'nyu'.
            pred: Only used for 'nyu'; when true a saved ``.npy`` prediction is
                loaded and resized to 480x640 instead of the PNG ground truth.

        Raises:
            NotImplementedError: For any other ``dataset_name``.
        """
        self.root_dir = root_dir
        name = files[index]
        gt_path = os.path.join(self.root_dir, name)

        if dataset_name == 'u4k':
            self.gt, self.edge = self._load_u4k(gt_path)
        elif dataset_name == 'gta':
            self.gt, self.edge = self._load_gta(gt_path)
        elif dataset_name == 'mid':
            self.gt, self.edge = self._load_mid(gt_path)
        elif dataset_name == 'nyu':
            self.gt, self.edge = self._load_nyu(gt_path, pred)
        else:
            raise NotImplementedError

        name = name.replace(".npy", "")  # u4k
        name = name.replace(".exr", "")  # gta
        self.name = name

    @staticmethod
    def _load_u4k(gt_path):
        """Load a u4k disparity file and convert it to depth with its factor file."""
        factor_path = gt_path.replace('val_gt', 'val_factor').replace('.npy', '.txt')
        with open(factor_path, 'r') as f:
            depth_factor = float(f.readline())

        gt_disp = np.load(gt_path, mmap_mode='c').astype(np.float32)
        edges = get_boundaries(gt_disp, th=1, dilation=0)
        return depth_factor / gt_disp, edges

    @staticmethod
    def _load_gta(gt_path):
        """Load a gta depth image; stored values are divided by 256."""
        gt_depth = np.array(imageio.imread(gt_path)).astype(np.float32) / 256
        edges = get_boundaries(gt_depth, th=1, dilation=0)
        return gt_depth, edges

    @staticmethod
    def _load_mid(gt_path):
        """Load a PFM disparity map and convert it to depth via its calib file."""
        calib_path = gt_path.replace('gts', 'calibs').replace('.pfm', '.txt')
        with open(calib_path, 'r') as f:
            calib_lines = f.readlines()

        # First number inside the '[' of line 0; lines 2 and 3 are 'key=value'.
        # NOTE(review): presumably focal length, doffs and baseline of the
        # Middlebury calib format -- confirm against the dataset docs.
        focal = float(calib_lines[0].strip().split(' ')[0].split('[')[1])
        doffs = float(calib_lines[2].strip().split('=')[1])
        baseline = float(calib_lines[3].strip().split('=')[1])
        depth_factor = baseline * focal

        disp_gt, _ = readPFM(gt_path)
        disp_gt = disp_gt.astype(np.float32)
        invalid_mask = disp_gt == np.inf

        depth_gt = depth_factor / (disp_gt + doffs)
        depth_gt = depth_gt / 1000  # presumably mm -> m; TODO confirm
        depth_gt[invalid_mask] = 0  # set to an invalid number

        # Edges come from the disparity with invalid pixels zeroed out.
        disp_for_edges = disp_gt.copy()
        disp_for_edges[invalid_mask] = 0
        edges = get_boundaries(disp_for_edges, th=1, dilation=0)
        return depth_gt, edges

    @staticmethod
    def _load_nyu(gt_path, pred):
        """Load NYU ground truth (PNG / 1000) or a saved prediction (.npy)."""
        if pred:
            depth_gt = np.load(gt_path.replace('png', 'npy'))
            depth_gt = nn.functional.interpolate(
                torch.tensor(depth_gt).unsqueeze(dim=0).unsqueeze(dim=0),
                (480, 640), mode='bilinear', align_corners=True).squeeze().numpy()
        else:
            with Image.open(gt_path) as pil_img:
                depth_gt = np.asarray(pil_img, dtype=np.float32) / 1000
        edges = get_boundaries(depth_gt, th=1, dilation=0)
        return depth_gt, edges


# Historical location of the saved NYU predictions; kept as the default so
# existing callers behave exactly as before.
_DEFAULT_NYU_PRED_DIR = '/ibex/ai/home/liz0l/codes/ZoeDepth/nfs/save/nyu'


class ImageDataset:
    """Index-aligned pairing of RGB images with optional ground-truth depth.

    Image and ground-truth file lists are each sorted and then paired by
    position, so both directories are assumed to sort into matching order.
    """

    def __init__(self, rgb_image_dir, gt_dir=None, dataset_name='',
                 nyu_pred_dir=_DEFAULT_NYU_PRED_DIR):
        """List the files of the dataset.

        Args:
            rgb_image_dir: Directory with the RGB images.
            gt_dir: Directory with ground-truth depth, or None.
            dataset_name: Dataset key forwarded to the loaders.
            nyu_pred_dir: Directory with saved NYU predictions; only read when
                ``dataset_name`` is 'nyu'.
        """
        self.rgb_image_dir = rgb_image_dir
        self.files = sorted(os.listdir(self.rgb_image_dir))
        self.gt_dir = gt_dir
        self.dataset_name = dataset_name
        self.nyu_pred_dir = nyu_pred_dir
        if gt_dir is not None:
            self.gt_files = sorted(os.listdir(self.gt_dir))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        """Return the sample at ``index``.

        Returns:
            ``(Images, DepthMap, DepthMap)`` for 'nyu' (image, ground truth,
            saved prediction); ``(Images, DepthMap)`` when ground truth is
            available; otherwise a single ``Images``.
        """
        # The dataset name is always forwarded so that the no-ground-truth
        # path selects the same image loader as the other paths.
        image = Images(self.rgb_image_dir, self.files, index, self.dataset_name)

        if self.dataset_name == 'nyu':
            depth = DepthMap(self.gt_dir, self.gt_files, index, self.dataset_name)
            saved_pred = DepthMap(self.nyu_pred_dir, self.gt_files, index,
                                  self.dataset_name, pred=True)
            return image, depth, saved_pred

        if self.gt_dir is not None:
            return image, DepthMap(self.gt_dir, self.gt_files, index, self.dataset_name)
        return image


def crop(img, crop_bbox):
    """Cut a window out of a 4-D tensor and build the matching binary mask.

    Args:
        img: Tensor indexed as ``[:, :, y, x]``.
        crop_bbox: ``(y1, y2, x1, x2)`` window bounds.

    Returns:
        ``(window, mask)`` where ``mask`` is a ``(1, 1, H, W)`` float tensor of
        the input's spatial size that is 1 inside the window and 0 elsewhere.
    """
    crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox
    template = torch.zeros((1, 1, img.shape[-2], img.shape[-1]), dtype=torch.float)
    template[:, :, crop_y1:crop_y2, crop_x1:crop_x2] = 1.0
    window = img[:, :, crop_y1:crop_y2, crop_x1:crop_x2]
    return window, template


# def generatemask(size):
#     # Generates a Guassian mask
#     mask = np.zeros(size, dtype=np.float32)
#     sigma = int(size[0]/16)
#     k_size = int(2 * np.ceil(2 * int(size[0]/16)) + 1)
#     mask[int(0.15*size[0]):size[0] - int(0.15*size[0]),
int(0.15*size[1]): size[1] - int(0.15*size[1])] = 1 # mask = cv2.GaussianBlur(mask, (int(k_size), int(k_size)), sigma) # mask = (mask - mask.min()) / (mask.max() - mask.min()) # mask = mask.astype(np.float32) # return mask def generatemask(size): # Generates a Guassian mask mask = np.zeros(size, dtype=np.float32) sigma = int(size[0]/16) k_size = int(2 * np.ceil(2 * int(size[0]/16)) + 1) mask[int(0.1*size[0]):size[0] - int(0.1*size[0]), int(0.1*size[1]): size[1] - int(0.1*size[1])] = 1 mask = cv2.GaussianBlur(mask, (int(k_size), int(k_size)), sigma) mask = (mask - mask.min()) / (mask.max() - mask.min()) mask = mask.astype(np.float32) return mask def generatemask_coarse(size): # Generates a Guassian mask mask = np.zeros(size, dtype=np.float32) sigma = int(size[0]/64) k_size = int(2 * np.ceil(2 * int(size[0]/64)) + 1) mask[int(0.001*size[0]):size[0] - int(0.001*size[0]), int(0.001*size[1]): size[1] - int(0.001*size[1])] = 1 mask = cv2.GaussianBlur(mask, (int(k_size), int(k_size)), sigma) mask = (mask - mask.min()) / (mask.max() - mask.min()) mask = mask.astype(np.float32) return mask class RunningAverageMap: """A dictionary of running averages.""" def __init__(self, average_map, count_map): self.average_map = average_map self.count_map = count_map self.average_map = self.average_map / self.count_map def update(self, pred_map, ct_map): self.average_map = (pred_map + self.count_map * self.average_map) / (self.count_map + ct_map) self.count_map = self.count_map + ct_map # default size [540, 960] # x_start, y_start = [0, 540, 1080, 1620], [0, 960, 1920, 2880] def regular_tile(model, image, offset_x=0, offset_y=0, img_lr=None, iter_pred=None, boundary=0, update=False, avg_depth_map=None, blr_mask=False): # crop size # height = 540 # width = 960 height = CROP_SIZE[0] width = CROP_SIZE[1] assert offset_x >= 0 and offset_y >= 0 tile_num_x = (IMG_RESOLUTION[1] - offset_x) // width tile_num_y = (IMG_RESOLUTION[0] - offset_y) // height x_start = [width * x + offset_x for x in 
range(tile_num_x)] y_start = [height * y + offset_y for y in range(tile_num_y)] imgs_crop = [] crop_areas = [] bboxs_roi = [] bboxs_raw = [] if iter_pred is not None: iter_pred = iter_pred.unsqueeze(dim=0).unsqueeze(dim=0) iter_priors = [] for x in x_start: # w for y in y_start: # h bbox = (int(y), int(y+height), int(x), int(x+width)) img_crop, crop_area = crop(image, bbox) imgs_crop.append(img_crop) crop_areas.append(crop_area) crop_y1, crop_y2, crop_x1, crop_x2 = bbox bbox_roi = torch.tensor([crop_x1 / IMG_RESOLUTION[1] * 512, crop_y1 / IMG_RESOLUTION[0] * 384, crop_x2 / IMG_RESOLUTION[1] * 512, crop_y2 / IMG_RESOLUTION[0] * 384]) bboxs_roi.append(bbox_roi) bbox_raw = torch.tensor([crop_x1, crop_y1, crop_x2, crop_y2]) bboxs_raw.append(bbox_raw) if iter_pred is not None: iter_prior, _ = crop(iter_pred, bbox) iter_priors.append(iter_prior) crop_areas = torch.cat(crop_areas, dim=0) imgs_crop = torch.cat(imgs_crop, dim=0) bboxs_roi = torch.stack(bboxs_roi, dim=0) bboxs_raw = torch.stack(bboxs_raw, dim=0) if iter_pred is not None: iter_priors = torch.cat(iter_priors, dim=0) iter_priors = TRANSFORM(iter_priors) iter_priors = iter_priors.cuda().float() crop_areas = TRANSFORM(crop_areas) imgs_crop = TRANSFORM(imgs_crop) imgs_crop = imgs_crop.cuda().float() bboxs_roi = bboxs_roi.cuda().float() crop_areas = crop_areas.cuda().float() img_lr = img_lr.cuda().float() pred_depth_crops = [] with torch.no_grad(): for i, (img, bbox, crop_area) in enumerate(zip(imgs_crop, bboxs_roi, crop_areas)): if iter_pred is not None: iter_prior = iter_priors[i].unsqueeze(dim=0) else: iter_prior = None if i == 0: out_dict = model(img.unsqueeze(dim=0), mode='eval', image_raw=img_lr, bbox=bbox.unsqueeze(dim=0), crop_area=crop_area.unsqueeze(dim=0), iter_prior=iter_prior if update is True else None) whole_depth_pred = out_dict['coarse_depth_pred'] # return whole_depth_pred.squeeze() # pred_depth_crop = out_dict['fine_depth_pred'] pred_depth_crop = out_dict['metric_depth'] else: pred_depth_crop = 
model(img.unsqueeze(dim=0), mode='eval', image_raw=img_lr, bbox=bbox.unsqueeze(dim=0), crop_area=crop_area.unsqueeze(dim=0), iter_prior=iter_prior if update is True else None)['metric_depth'] # pred_depth_crop = model(img.unsqueeze(dim=0), mode='eval', image_raw=img_lr, bbox=bbox.unsqueeze(dim=0), crop_area=crop_area.unsqueeze(dim=0), iter_prior=iter_prior if update is True else None)['fine_depth_pred'] pred_depth_crop = nn.functional.interpolate( pred_depth_crop, (height, width), mode='bilinear', align_corners=True) # pred_depth_crop = nn.functional.interpolate( # pred_depth_crop, (height, width), mode='nearest') pred_depth_crops.append(pred_depth_crop.squeeze()) whole_depth_pred = whole_depth_pred.squeeze() whole_depth_pred = nn.functional.interpolate(whole_depth_pred.unsqueeze(dim=0).unsqueeze(dim=0), IMG_RESOLUTION, mode='bilinear', align_corners=True).squeeze() ####### stich part inner_idx = 0 init_flag = False if offset_x == 0 and offset_y == 0: init_flag = True # pred_depth = whole_depth_pred pred_depth = torch.zeros(IMG_RESOLUTION, device=pred_depth_crops[inner_idx].device) else: iter_pred = iter_pred.squeeze() pred_depth = iter_pred blur_mask = generatemask((height, width)) + 1e-3 count_map = torch.zeros(IMG_RESOLUTION, device=pred_depth_crops[inner_idx].device) for ii, x in enumerate(x_start): for jj, y in enumerate(y_start): if init_flag: # pred_depth[y: y+height, x: x+width] = blur_mask * pred_depth_crops[inner_idx] + (1 - blur_mask) * crop_temp # pred_depth[y: y+height, x: x+width] = blur_mask * pred_depth_crops[inner_idx] + (1 - blur_mask) * crop_temp blur_mask = torch.tensor(blur_mask, device=whole_depth_pred.device) count_map[y: y+height, x: x+width] = blur_mask pred_depth[y: y+height, x: x+width] = pred_depth_crops[inner_idx] * blur_mask else: # ensemble with running mean if blr_mask: blur_mask = torch.tensor(blur_mask, device=whole_depth_pred.device) count_map = torch.zeros(IMG_RESOLUTION, device=pred_depth_crops[inner_idx].device) count_map[y: 
y+height, x: x+width] = blur_mask pred_map = torch.zeros(IMG_RESOLUTION, device=pred_depth_crops[inner_idx].device) pred_map[y: y+height, x: x+width] = pred_depth_crops[inner_idx] * blur_mask avg_depth_map.update(pred_map, count_map) else: if boundary != 0: count_map = torch.zeros(IMG_RESOLUTION, device=pred_depth_crops[inner_idx].device) count_map[y+boundary: y+height-boundary, x+boundary: x+width-boundary] = 1 pred_map = torch.zeros(IMG_RESOLUTION, device=pred_depth_crops[inner_idx].device) pred_map[y+boundary: y+height-boundary, x+boundary: x+width-boundary] = pred_depth_crops[inner_idx][boundary:-boundary, boundary:-boundary] avg_depth_map.update(pred_map, count_map) else: count_map = torch.zeros(IMG_RESOLUTION, device=pred_depth_crops[inner_idx].device) count_map[y: y+height, x: x+width] = 1 pred_map = torch.zeros(IMG_RESOLUTION, device=pred_depth_crops[inner_idx].device) pred_map[y: y+height, x: x+width] = pred_depth_crops[inner_idx] avg_depth_map.update(pred_map, count_map) inner_idx += 1 if init_flag: avg_depth_map = RunningAverageMap(pred_depth, count_map) # blur_mask = generatemask_coarse(IMG_RESOLUTION) # blur_mask = torch.tensor(blur_mask, device=whole_depth_pred.device) # count_map = (1 - blur_mask) # pred_map = whole_depth_pred * (1 - blur_mask) # avg_depth_map.update(pred_map, count_map) return avg_depth_map def regular_tile_param(model, image, offset_x=0, offset_y=0, img_lr=None, iter_pred=None, boundary=0, update=False, avg_depth_map=None, blr_mask=False, crop_size=None, img_resolution=None, transform=None): # crop size # height = 540 # width = 960 height = crop_size[0] width = crop_size[1] assert offset_x >= 0 and offset_y >= 0 tile_num_x = (img_resolution[1] - offset_x) // width tile_num_y = (img_resolution[0] - offset_y) // height x_start = [width * x + offset_x for x in range(tile_num_x)] y_start = [height * y + offset_y for y in range(tile_num_y)] imgs_crop = [] crop_areas = [] bboxs_roi = [] bboxs_raw = [] if iter_pred is not None: iter_pred 
= iter_pred.unsqueeze(dim=0).unsqueeze(dim=0) iter_priors = [] for x in x_start: # w for y in y_start: # h bbox = (int(y), int(y+height), int(x), int(x+width)) img_crop, crop_area = crop(image, bbox) imgs_crop.append(img_crop) crop_areas.append(crop_area) crop_y1, crop_y2, crop_x1, crop_x2 = bbox bbox_roi = torch.tensor([crop_x1 / img_resolution[1] * 512, crop_y1 / img_resolution[0] * 384, crop_x2 / img_resolution[1] * 512, crop_y2 / img_resolution[0] * 384]) bboxs_roi.append(bbox_roi) bbox_raw = torch.tensor([crop_x1, crop_y1, crop_x2, crop_y2]) bboxs_raw.append(bbox_raw) if iter_pred is not None: iter_prior, _ = crop(iter_pred, bbox) iter_priors.append(iter_prior) crop_areas = torch.cat(crop_areas, dim=0) imgs_crop = torch.cat(imgs_crop, dim=0) bboxs_roi = torch.stack(bboxs_roi, dim=0) bboxs_raw = torch.stack(bboxs_raw, dim=0) if iter_pred is not None: iter_priors = torch.cat(iter_priors, dim=0) iter_priors = transform(iter_priors) iter_priors = iter_priors.to(image.device).float() crop_areas = transform(crop_areas) imgs_crop = transform(imgs_crop) imgs_crop = imgs_crop.to(image.device).float() bboxs_roi = bboxs_roi.to(image.device).float() crop_areas = crop_areas.to(image.device).float() img_lr = img_lr.to(image.device).float() pred_depth_crops = [] with torch.no_grad(): for i, (img, bbox, crop_area) in enumerate(zip(imgs_crop, bboxs_roi, crop_areas)): if iter_pred is not None: iter_prior = iter_priors[i].unsqueeze(dim=0) else: iter_prior = None if i == 0: out_dict = model(img.unsqueeze(dim=0), mode='eval', image_raw=img_lr, bbox=bbox.unsqueeze(dim=0), crop_area=crop_area.unsqueeze(dim=0), iter_prior=iter_prior if update is True else None) whole_depth_pred = out_dict['coarse_depth_pred'] # return whole_depth_pred.squeeze() # pred_depth_crop = out_dict['fine_depth_pred'] pred_depth_crop = out_dict['metric_depth'] else: pred_depth_crop = model(img.unsqueeze(dim=0), mode='eval', image_raw=img_lr, bbox=bbox.unsqueeze(dim=0), crop_area=crop_area.unsqueeze(dim=0), 
iter_prior=iter_prior if update is True else None)['metric_depth'] # pred_depth_crop = model(img.unsqueeze(dim=0), mode='eval', image_raw=img_lr, bbox=bbox.unsqueeze(dim=0), crop_area=crop_area.unsqueeze(dim=0), iter_prior=iter_prior if update is True else None)['fine_depth_pred'] pred_depth_crop = nn.functional.interpolate( pred_depth_crop, (height, width), mode='bilinear', align_corners=True) # pred_depth_crop = nn.functional.interpolate( # pred_depth_crop, (height, width), mode='nearest') pred_depth_crops.append(pred_depth_crop.squeeze()) whole_depth_pred = whole_depth_pred.squeeze() whole_depth_pred = nn.functional.interpolate(whole_depth_pred.unsqueeze(dim=0).unsqueeze(dim=0), img_resolution, mode='bilinear', align_corners=True).squeeze() ####### stich part inner_idx = 0 init_flag = False if offset_x == 0 and offset_y == 0: init_flag = True # pred_depth = whole_depth_pred pred_depth = torch.zeros(img_resolution, device=pred_depth_crops[inner_idx].device) else: iter_pred = iter_pred.squeeze() pred_depth = iter_pred blur_mask = generatemask((height, width)) + 1e-3 count_map = torch.zeros(img_resolution, device=pred_depth_crops[inner_idx].device) for ii, x in enumerate(x_start): for jj, y in enumerate(y_start): if init_flag: # pred_depth[y: y+height, x: x+width] = blur_mask * pred_depth_crops[inner_idx] + (1 - blur_mask) * crop_temp # pred_depth[y: y+height, x: x+width] = blur_mask * pred_depth_crops[inner_idx] + (1 - blur_mask) * crop_temp blur_mask = torch.tensor(blur_mask, device=whole_depth_pred.device) count_map[y: y+height, x: x+width] = blur_mask pred_depth[y: y+height, x: x+width] = pred_depth_crops[inner_idx] * blur_mask else: # ensemble with running mean if blr_mask: blur_mask = torch.tensor(blur_mask, device=whole_depth_pred.device) count_map = torch.zeros(img_resolution, device=pred_depth_crops[inner_idx].device) count_map[y: y+height, x: x+width] = blur_mask pred_map = torch.zeros(img_resolution, device=pred_depth_crops[inner_idx].device) pred_map[y: 
y+height, x: x+width] = pred_depth_crops[inner_idx] * blur_mask avg_depth_map.update(pred_map, count_map) else: if boundary != 0: count_map = torch.zeros(img_resolution, device=pred_depth_crops[inner_idx].device) count_map[y+boundary: y+height-boundary, x+boundary: x+width-boundary] = 1 pred_map = torch.zeros(img_resolution, device=pred_depth_crops[inner_idx].device) pred_map[y+boundary: y+height-boundary, x+boundary: x+width-boundary] = pred_depth_crops[inner_idx][boundary:-boundary, boundary:-boundary] avg_depth_map.update(pred_map, count_map) else: count_map = torch.zeros(img_resolution, device=pred_depth_crops[inner_idx].device) count_map[y: y+height, x: x+width] = 1 pred_map = torch.zeros(img_resolution, device=pred_depth_crops[inner_idx].device) pred_map[y: y+height, x: x+width] = pred_depth_crops[inner_idx] avg_depth_map.update(pred_map, count_map) inner_idx += 1 if init_flag: avg_depth_map = RunningAverageMap(pred_depth, count_map) # blur_mask = generatemask_coarse(img_resolution) # blur_mask = torch.tensor(blur_mask, device=whole_depth_pred.device) # count_map = (1 - blur_mask) # pred_map = whole_depth_pred * (1 - blur_mask) # avg_depth_map.update(pred_map, count_map) return avg_depth_map def random_tile(model, image, img_lr=None, iter_pred=None, boundary=0, update=False, avg_depth_map=None, blr_mask=False): height = CROP_SIZE[0] width = CROP_SIZE[1] x_start = [random.randint(0, IMG_RESOLUTION[1] - width - 1)] y_start = [random.randint(0, IMG_RESOLUTION[0] - height - 1)] imgs_crop = [] crop_areas = [] bboxs_roi = [] bboxs_raw = [] if iter_pred is not None: iter_pred = iter_pred.unsqueeze(dim=0).unsqueeze(dim=0) iter_priors = [] for x in x_start: # w for y in y_start: # h bbox = (int(y), int(y+height), int(x), int(x+width)) img_crop, crop_area = crop(image, bbox) imgs_crop.append(img_crop) crop_areas.append(crop_area) crop_y1, crop_y2, crop_x1, crop_x2 = bbox bbox_roi = torch.tensor([crop_x1 / IMG_RESOLUTION[1] * 512, crop_y1 / IMG_RESOLUTION[0] * 384, 
crop_x2 / IMG_RESOLUTION[1] * 512, crop_y2 / IMG_RESOLUTION[0] * 384]) bboxs_roi.append(bbox_roi) bbox_raw = torch.tensor([crop_x1, crop_y1, crop_x2, crop_y2]) bboxs_raw.append(bbox_raw) if iter_pred is not None: iter_prior, _ = crop(iter_pred, bbox) iter_priors.append(iter_prior) crop_areas = torch.cat(crop_areas, dim=0) imgs_crop = torch.cat(imgs_crop, dim=0) bboxs_roi = torch.stack(bboxs_roi, dim=0) bboxs_raw = torch.stack(bboxs_raw, dim=0) if iter_pred is not None: iter_priors = torch.cat(iter_priors, dim=0) iter_priors = TRANSFORM(iter_priors) iter_priors = iter_priors.cuda().float() crop_areas = TRANSFORM(crop_areas) imgs_crop = TRANSFORM(imgs_crop) imgs_crop = imgs_crop.cuda().float() bboxs_roi = bboxs_roi.cuda().float() crop_areas = crop_areas.cuda().float() img_lr = img_lr.cuda().float() pred_depth_crops = [] with torch.no_grad(): for i, (img, bbox, crop_area) in enumerate(zip(imgs_crop, bboxs_roi, crop_areas)): if iter_pred is not None: iter_prior = iter_priors[i].unsqueeze(dim=0) else: iter_prior = None if i == 0: out_dict = model(img.unsqueeze(dim=0), mode='eval', image_raw=img_lr, bbox=bbox.unsqueeze(dim=0), crop_area=crop_area.unsqueeze(dim=0), iter_prior=iter_prior if update is True else None) whole_depth_pred = out_dict['coarse_depth_pred'] pred_depth_crop = out_dict['metric_depth'] # return whole_depth_pred.squeeze() else: pred_depth_crop = model(img.unsqueeze(dim=0), mode='eval', image_raw=img_lr, bbox=bbox.unsqueeze(dim=0), crop_area=crop_area.unsqueeze(dim=0), iter_prior=iter_prior if update is True else None)['metric_depth'] pred_depth_crop = nn.functional.interpolate( pred_depth_crop, (height, width), mode='bilinear', align_corners=True) # pred_depth_crop = nn.functional.interpolate( # pred_depth_crop, (height, width), mode='nearest') pred_depth_crops.append(pred_depth_crop.squeeze()) whole_depth_pred = whole_depth_pred.squeeze() ####### stich part inner_idx = 0 init_flag = False iter_pred = iter_pred.squeeze() pred_depth = iter_pred blur_mask 
= generatemask((height, width)) + 1e-3 for ii, x in enumerate(x_start): for jj, y in enumerate(y_start): if init_flag: # wont be here crop_temp = copy.deepcopy(whole_depth_pred[y: y+height, x: x+width]) blur_mask = torch.ones((height, width)) blur_mask = torch.tensor(blur_mask, device=whole_depth_pred.device) pred_depth[y: y+height, x: x+width] = blur_mask * pred_depth_crops[inner_idx]+ (1 - blur_mask) * crop_temp else: if blr_mask: blur_mask = torch.tensor(blur_mask, device=whole_depth_pred.device) count_map = torch.zeros(IMG_RESOLUTION, device=pred_depth_crops[inner_idx].device) count_map[y: y+height, x: x+width] = blur_mask pred_map = torch.zeros(IMG_RESOLUTION, device=pred_depth_crops[inner_idx].device) pred_map[y: y+height, x: x+width] = pred_depth_crops[inner_idx] * blur_mask avg_depth_map.update(pred_map, count_map) else: # ensemble with running mean if boundary != 0: count_map = torch.zeros(IMG_RESOLUTION, device=pred_depth_crops[inner_idx].device) count_map[y+boundary: y+height-boundary, x+boundary: x+width-boundary] = 1 pred_map = torch.zeros(IMG_RESOLUTION, device=pred_depth_crops[inner_idx].device) pred_map[y+boundary: y+height-boundary, x+boundary: x+width-boundary] = pred_depth_crops[inner_idx][boundary:-boundary, boundary:-boundary] avg_depth_map.update(pred_map, count_map) else: count_map = torch.zeros(IMG_RESOLUTION, device=pred_depth_crops[inner_idx].device) count_map[y: y+height, x: x+width] = 1 pred_map = torch.zeros(IMG_RESOLUTION, device=pred_depth_crops[inner_idx].device) pred_map[y: y+height, x: x+width] = pred_depth_crops[inner_idx] avg_depth_map.update(pred_map, count_map) inner_idx += 1 if avg_depth_map is None: return pred_depth def random_tile_param(model, image, img_lr=None, iter_pred=None, boundary=0, update=False, avg_depth_map=None, blr_mask=False, crop_size=None, img_resolution=None, transform=None): height = crop_size[0] width = crop_size[1] x_start = [random.randint(0, img_resolution[1] - width - 1)] y_start = [random.randint(0, 
img_resolution[0] - height - 1)] imgs_crop = [] crop_areas = [] bboxs_roi = [] bboxs_raw = [] if iter_pred is not None: iter_pred = iter_pred.unsqueeze(dim=0).unsqueeze(dim=0) iter_priors = [] for x in x_start: # w for y in y_start: # h bbox = (int(y), int(y+height), int(x), int(x+width)) img_crop, crop_area = crop(image, bbox) imgs_crop.append(img_crop) crop_areas.append(crop_area) crop_y1, crop_y2, crop_x1, crop_x2 = bbox bbox_roi = torch.tensor([crop_x1 / img_resolution[1] * 512, crop_y1 / img_resolution[0] * 384, crop_x2 / img_resolution[1] * 512, crop_y2 / img_resolution[0] * 384]) bboxs_roi.append(bbox_roi) bbox_raw = torch.tensor([crop_x1, crop_y1, crop_x2, crop_y2]) bboxs_raw.append(bbox_raw) if iter_pred is not None: iter_prior, _ = crop(iter_pred, bbox) iter_priors.append(iter_prior) crop_areas = torch.cat(crop_areas, dim=0) imgs_crop = torch.cat(imgs_crop, dim=0) bboxs_roi = torch.stack(bboxs_roi, dim=0) bboxs_raw = torch.stack(bboxs_raw, dim=0) if iter_pred is not None: iter_priors = torch.cat(iter_priors, dim=0) iter_priors = transform(iter_priors) iter_priors = iter_priors.cuda().float() crop_areas = transform(crop_areas) imgs_crop = transform(imgs_crop) imgs_crop = imgs_crop.cuda().float() bboxs_roi = bboxs_roi.cuda().float() crop_areas = crop_areas.cuda().float() img_lr = img_lr.cuda().float() pred_depth_crops = [] with torch.no_grad(): for i, (img, bbox, crop_area) in enumerate(zip(imgs_crop, bboxs_roi, crop_areas)): if iter_pred is not None: iter_prior = iter_priors[i].unsqueeze(dim=0) else: iter_prior = None if i == 0: out_dict = model(img.unsqueeze(dim=0), mode='eval', image_raw=img_lr, bbox=bbox.unsqueeze(dim=0), crop_area=crop_area.unsqueeze(dim=0), iter_prior=iter_prior if update is True else None) whole_depth_pred = out_dict['coarse_depth_pred'] pred_depth_crop = out_dict['metric_depth'] # return whole_depth_pred.squeeze() else: pred_depth_crop = model(img.unsqueeze(dim=0), mode='eval', image_raw=img_lr, bbox=bbox.unsqueeze(dim=0), 
crop_area=crop_area.unsqueeze(dim=0), iter_prior=iter_prior if update is True else None)['metric_depth'] pred_depth_crop = nn.functional.interpolate( pred_depth_crop, (height, width), mode='bilinear', align_corners=True) # pred_depth_crop = nn.functional.interpolate( # pred_depth_crop, (height, width), mode='nearest') pred_depth_crops.append(pred_depth_crop.squeeze()) whole_depth_pred = whole_depth_pred.squeeze() ####### stich part inner_idx = 0 init_flag = False iter_pred = iter_pred.squeeze() pred_depth = iter_pred blur_mask = generatemask((height, width)) + 1e-3 for ii, x in enumerate(x_start): for jj, y in enumerate(y_start): if init_flag: # wont be here crop_temp = copy.deepcopy(whole_depth_pred[y: y+height, x: x+width]) blur_mask = torch.ones((height, width)) blur_mask = torch.tensor(blur_mask, device=whole_depth_pred.device) pred_depth[y: y+height, x: x+width] = blur_mask * pred_depth_crops[inner_idx]+ (1 - blur_mask) * crop_temp else: if blr_mask: blur_mask = torch.tensor(blur_mask, device=whole_depth_pred.device) count_map = torch.zeros(img_resolution, device=pred_depth_crops[inner_idx].device) count_map[y: y+height, x: x+width] = blur_mask pred_map = torch.zeros(img_resolution, device=pred_depth_crops[inner_idx].device) pred_map[y: y+height, x: x+width] = pred_depth_crops[inner_idx] * blur_mask avg_depth_map.update(pred_map, count_map) else: # ensemble with running mean if boundary != 0: count_map = torch.zeros(img_resolution, device=pred_depth_crops[inner_idx].device) count_map[y+boundary: y+height-boundary, x+boundary: x+width-boundary] = 1 pred_map = torch.zeros(img_resolution, device=pred_depth_crops[inner_idx].device) pred_map[y+boundary: y+height-boundary, x+boundary: x+width-boundary] = pred_depth_crops[inner_idx][boundary:-boundary, boundary:-boundary] avg_depth_map.update(pred_map, count_map) else: count_map = torch.zeros(img_resolution, device=pred_depth_crops[inner_idx].device) count_map[y: y+height, x: x+width] = 1 pred_map = 
torch.zeros(img_resolution, device=pred_depth_crops[inner_idx].device) pred_map[y: y+height, x: x+width] = pred_depth_crops[inner_idx] avg_depth_map.update(pred_map, count_map) inner_idx += 1 if avg_depth_map is None: return pred_depth def colorize_infer(value, cmap='magma_r', vmin=None, vmax=None): # normalize vmin = value.min() if vmin is None else vmin # vmax = value.max() if vmax is None else vmax vmax = np.percentile(value, 95) if vmax is None else vmax if vmin != vmax: value = (value - vmin) / (vmax - vmin) # vmin..vmax else: value = value * 0. cmapper = matplotlib.cm.get_cmap(cmap) value = cmapper(value, bytes=True) # ((1)xhxwx4) value = value[:, :, :3] # bgr -> rgb rgb_value = value[..., ::-1] return rgb_value def colorize(value, vmin=None, vmax=None, cmap='turbo_r', invalid_val=-99, invalid_mask=None, background_color=(128, 128, 128, 255), gamma_corrected=False, value_transform=None, dataset_name=None): """Converts a depth map to a color image. Args: value (torch.Tensor, numpy.ndarry): Input depth map. Shape: (H, W) or (1, H, W) or (1, 1, H, W). All singular dimensions are squeezed vmin (float, optional): vmin-valued entries are mapped to start color of cmap. If None, value.min() is used. Defaults to None. vmax (float, optional): vmax-valued entries are mapped to end color of cmap. If None, value.max() is used. Defaults to None. cmap (str, optional): matplotlib colormap to use. Defaults to 'magma_r'. invalid_val (int, optional): Specifies value of invalid pixels that should be colored as 'background_color'. Defaults to -99. invalid_mask (numpy.ndarray, optional): Boolean mask for invalid regions. Defaults to None. background_color (tuple[int], optional): 4-tuple RGB color to give to invalid pixels. Defaults to (128, 128, 128, 255). gamma_corrected (bool, optional): Apply gamma correction to colored image. Defaults to False. value_transform (Callable, optional): Apply transform function to valid pixels before coloring. Defaults to None. 
Returns: numpy.ndarray, dtype - uint8: Colored depth map. Shape: (H, W, 4) """ if isinstance(value, torch.Tensor): value = value.detach().cpu().numpy() value = value.squeeze() if invalid_mask is None: invalid_mask = value == invalid_val mask = np.logical_not(invalid_mask) # normalize # vmin = np.percentile(value[mask],2) if vmin is None else vmin # vmin = value.min() if vmin is None else vmin # vmax = np.percentile(value[mask],95) if vmax is None else vmax # mid gt if dataset_name == 'mid': vmin = np.percentile(value[mask],2) if vmin is None else vmin vmax = np.percentile(value[mask],85) if vmax is None else vmax else: vmin = value.min() if vmin is None else vmin vmax = np.percentile(value[mask],95) if vmax is None else vmax if vmin != vmax: value = (value - vmin) / (vmax - vmin) # vmin..vmax else: # Avoid 0-division value = value * 0. # squeeze last dim if it exists # grey out the invalid values value[invalid_mask] = np.nan cmapper = matplotlib.cm.get_cmap(cmap) if value_transform: value = value_transform(value) # value = value / value.max() value = cmapper(value, bytes=True) # (nxmx4) # img = value[:, :, :] img = value[...] img[invalid_mask] = background_color # return img.transpose((2, 0, 1)) if gamma_corrected: # gamma correction img = img / 255 img = np.power(img, 2.2) img = img * 255 img = img.astype(np.uint8) return img def rescale(A, lbound=0, ubound=1): """ Rescale an array to [lbound, ubound]. Parameters: - A: Input data as numpy array - lbound: Lower bound of the scale, default is 0. - ubound: Upper bound of the scale, default is 1. 
def rescale(A, lbound=0, ubound=1):
    """
    Rescale an array linearly so its minimum maps to lbound and its maximum to ubound.

    Parameters:
    - A: Input data as numpy array
    - lbound: Lower bound of the scale, default is 0.
    - ubound: Upper bound of the scale, default is 1.

    Returns:
    - Rescaled array. When every element of A is equal there is no range to
      stretch, so an array filled with lbound is returned instead of NaNs.
    """
    A_min = np.min(A)
    A_max = np.max(A)
    span = A_max - A_min
    if span == 0:
        # Avoid 0-division: a constant input would otherwise evaluate 0 / 0.
        return (A - A_min) * 0.0 + lbound
    return (ubound - lbound) * (A - A_min) / span + lbound
def _refine_with_shifted_tiles(model, img, img_lr, avg_depth_map, blr_mask):
    """Fuse three regular-tile passes, each shifted by half a crop, into avg_depth_map."""
    half_h = CROP_SIZE[0] // 2
    half_w = CROP_SIZE[1] // 2
    # Same order as before: shift in x, shift in y, then shift in both.
    for offset_x, offset_y in ((half_w, 0), (0, half_h), (half_w, half_h)):
        regular_tile(
            model, img, offset_x=offset_x, offset_y=offset_y, img_lr=img_lr,
            iter_pred=avg_depth_map.average_map, boundary=BOUNDARY, update=True,
            avg_depth_map=avg_depth_map, blr_mask=blr_mask)


def run(model, dataset, gt_dir=None, show_path=None, show=False, save_flag=False, save_path=None, mode=None, dataset_name=None, base_zoed=False, blr_mask=False):
    """Predict a depth map for every dataset item, optionally saving, visualising and scoring it.

    Args:
        model: Depth model. With ``base_zoed`` it is called directly on the
            whole image and its ``'metric_depth'`` output is used; otherwise
            it is handed to ``regular_tile`` / ``random_tile``.
        dataset: Indexable, sized collection. Items are ``images`` when
            ``gt_dir`` is None, ``(images, depths)`` otherwise, and
            ``(images, depths, pred_depths)`` when ``dataset_name == 'nyu'``.
            ``images`` must expose ``rgb_image`` and ``name``; ``depths``
            must expose ``gt`` and ``edge``.
        gt_dir (str, optional): When not None, metrics are computed per image
            and their running average is printed at the end.
        show_path (str, optional): Directory for colourised ``<name>.png`` files (used if ``show``).
        show (bool, optional): Write a colourised prediction per image.
        save_flag (bool, optional): Write the raw prediction as ``<name>.npy``.
        save_path (str, optional): Directory for the ``.npy`` files (used if ``save_flag``).
        mode (str, optional): Tiling schedule, ignored with ``base_zoed``.
            ``'p49'`` adds three half-crop-shifted tile passes; ``'r<N>'``
            adds those plus N random tiles. ``'p16'``, None and any other
            value keep only the initial unshifted pass.
        dataset_name (str, optional): ``'nyu'`` selects the NYU item layout and evaluation settings.
        base_zoed (bool, optional): Skip tiling and run the model once on the full image.
        blr_mask (bool, optional): Forwarded to the tiling helpers.

    Returns:
        The running-average depth map object of the last processed image, or
        None when the dataset is empty.
    """
    data_len = len(dataset)
    if gt_dir is not None:
        metrics_avg = RunningAverageDict()

    # Stays None for an empty dataset so the final return cannot hit an unbound local.
    avg_depth_map = None
    # mode defaults to None; guard before inspecting its first character.
    random_mode = isinstance(mode, str) and mode.startswith('r')

    for image_ind in tqdm(range(data_len)):
        if dataset_name == 'nyu':
            images, depths, pred_depths = dataset[image_ind]
        else:
            if gt_dir is None:
                images = dataset[image_ind]
            else:
                images, depths = dataset[image_ind]

        # Load image from dataset
        img = torch.tensor(images.rgb_image).unsqueeze(dim=0).permute(0, 3, 1, 2)  # shape: 1, 3, h, w
        img_lr = TRANSFORM(img)

        if base_zoed:
            with torch.no_grad():
                pred_depth = model(img.cuda())['metric_depth'].squeeze()
            avg_depth_map = RunningAverageMap(pred_depth)
        else:
            # Initial, unshifted tile pass; later passes are fused into its result.
            avg_depth_map = regular_tile(model, img, offset_x=0, offset_y=0, img_lr=img_lr)

            if mode == 'p49' or random_mode:
                _refine_with_shifted_tiles(model, img, img_lr, avg_depth_map, blr_mask)

            if random_mode:
                for _ in tqdm(range(int(mode[1:]))):
                    random_tile(
                        model, img, img_lr=img_lr, iter_pred=avg_depth_map.average_map,
                        boundary=BOUNDARY, update=True, avg_depth_map=avg_depth_map,
                        blr_mask=blr_mask)

        if show:
            color_map = copy.deepcopy(avg_depth_map.average_map)
            color_map = colorize_infer(color_map.detach().cpu().numpy())
            cv2.imwrite(os.path.join(show_path, '{}.png'.format(images.name)), color_map)

        if save_flag:
            np.save(os.path.join(save_path, '{}.npy'.format(images.name)), avg_depth_map.average_map.squeeze().detach().cpu().numpy())

        if gt_dir is not None:
            if dataset_name == 'nyu':
                metrics = compute_metrics(torch.tensor(depths.gt), avg_depth_map.average_map, disp_gt_edges=depths.edge, min_depth_eval=1e-3, max_depth_eval=10, garg_crop=False, eigen_crop=True, dataset='nyu', pred_depths=torch.tensor(pred_depths.gt))
            else:
                metrics = compute_metrics(torch.tensor(depths.gt), avg_depth_map.average_map, disp_gt_edges=depths.edge, min_depth_eval=1e-3, max_depth_eval=80, garg_crop=False, eigen_crop=False, dataset='')
            metrics_avg.update(metrics)
            print(metrics)

    if gt_dir is not None:
        print(metrics_avg.get_value())
    else:
        print("successful!")
    return avg_depth_map
if __name__ == '__main__':
    # Command-line entry point: parse options, set the module-level
    # configuration read by the functions above, build the model and run it.
    parser = argparse.ArgumentParser()
    parser.add_argument('--rgb_dir', type=str, required=True)
    parser.add_argument('--show_path', type=str, default=None)
    parser.add_argument("--ckp_path", type=str, required=True)
    parser.add_argument("-m", "--model", type=str, default="zoedepth")
    parser.add_argument("--model_cfg_path", type=str, default="")
    parser.add_argument("--gt_dir", type=str, default=None)
    parser.add_argument("--dataset_name", type=str, default=None)
    parser.add_argument("--show", action='store_true')
    parser.add_argument("--save", action='store_true')
    parser.add_argument("--save_path", type=str, default=None)
    parser.add_argument("--img_resolution", type=str, default=None)
    parser.add_argument("--crop_size", type=str, default=None)
    # 'p16' keeps only the initial tile pass; a None default used to crash
    # run() as soon as it inspected the mode string.
    parser.add_argument("--mode", type=str, default="p16")
    parser.add_argument("--base_zoed", action='store_true')
    parser.add_argument("--boundary", type=int, default=0)
    parser.add_argument("--blur_mask", action='store_true')

    args, unknown_args = parser.parse_known_args()

    # Fail early with a usage message instead of a TypeError from os.makedirs(None).
    if args.show and args.show_path is None:
        parser.error("--show requires --show_path")
    if args.save and args.save_path is None:
        parser.error("--save requires --save_path")

    # prepare some global args (plain module-level assignments; functions
    # defined in this file read them as globals)
    if args.dataset_name == 'u4k':
        IMG_RESOLUTION = (2160, 3840)
    elif args.dataset_name == 'gta':
        IMG_RESOLUTION = (1080, 1920)
    elif args.dataset_name == 'nyu':
        IMG_RESOLUTION = (480, 640)
    else:
        IMG_RESOLUTION = (2160, 3840)

    TRANSFORM = Compose([Resize(512, 384, keep_aspect_ratio=False, ensure_multiple_of=32, resize_method="minimal")])

    BOUNDARY = args.boundary

    # An explicit "<H>x<W>" overrides the per-dataset resolution.
    if args.img_resolution is not None:
        resolution_parts = args.img_resolution.split('x')
        IMG_RESOLUTION = (int(resolution_parts[0]), int(resolution_parts[1]))

    # Default crop is a quarter of the image along each axis unless "<H>x<W>" is given.
    CROP_SIZE = (int(IMG_RESOLUTION[0] // 4), int(IMG_RESOLUTION[1] // 4))
    if args.crop_size is not None:
        crop_parts = args.crop_size.split('x')
        CROP_SIZE = (int(crop_parts[0]), int(crop_parts[1]))

    print("\nCurrent image resolution: {}\n Current crop size: {}".format(IMG_RESOLUTION, CROP_SIZE))

    # Unrecognised CLI arguments are forwarded as config overrides.
    overwrite_kwargs = parse_unknown(unknown_args)
    overwrite_kwargs['model_cfg_path'] = args.model_cfg_path
    overwrite_kwargs["model"] = args.model

    config = get_config_user(args.model, **overwrite_kwargs)
    # Weights come from --ckp_path below, so the config's pretrained resource is cleared.
    config["pretrained_resource"] = ''
    model = build_model(config)
    model = load_ckpt(model, args.ckp_path)
    model.eval()
    model.cuda()

    # Create dataset from input images
    dataset_custom = ImageDataset(args.rgb_dir, args.gt_dir, args.dataset_name)

    # start running
    if args.show:
        os.makedirs(args.show_path, exist_ok=True)
    if args.save:
        os.makedirs(args.save_path, exist_ok=True)

    run(model, dataset_custom, args.gt_dir, args.show_path, args.show, args.save, args.save_path, mode=args.mode, dataset_name=args.dataset_name, base_zoed=args.base_zoed, blr_mask=args.blur_mask)