|
import numpy as np |
|
from functools import reduce |
|
|
|
def sparse_bilateral_filtering(
    depth, image, config, HR=False, mask=None, gsHR=True, edge_id=None, num_iter=None, num_gs_iter=None, spdb=False
):
    """
    Repeatedly detect depth discontinuities and filter the depth map around them.

    Each iteration:
      1. stores a copy of ``image`` and the current depth,
      2. calls ``vis_depth_discontinuity`` on the current depth,
      3. paints the flagged pixels black in the stored image copy,
      4. builds a {0, 1} discontinuity map (pixels where the *input*
         ``depth`` is 0 are forced to 1, pixels where ``mask`` is 0 are
         forced to 0),
      5. replaces the current depth with the result of ``bilateral_filter``.

    config:
        - filter_size: window size passed to ``bilateral_filter``; either a
          single value used for every iteration or a list indexed by the
          iteration number.
        - any further keys are read by ``vis_depth_discontinuity`` and
          ``bilateral_filter``, which receive ``config`` unchanged.

    Parameters:
        depth: array supporting ``.copy()``; never modified in place.
        image: array supporting ``.copy()``; it is indexed with the 2-D
            discontinuity maps and assigned 3-vectors, so presumably
            H x W x 3 -- TODO confirm against callers.
        HR, mask: forwarded to ``bilateral_filter`` (``mask`` also to
            ``vis_depth_discontinuity``).
        num_iter: number of iterations. When ``None`` it is inferred from
            ``config["filter_size"]``: its length if it is a list,
            otherwise 1.
        gsHR, edge_id, num_gs_iter, spdb: accepted for call compatibility
            but not used by this function.

    Returns:
        ``(save_images, save_depths)``: two lists with one entry per
        iteration. ``save_images[i]`` is the image copy with iteration
        ``i``'s discontinuities blacked out; ``save_depths[i]`` is the depth
        *entering* iteration ``i``. The depth produced by the final
        ``bilateral_filter`` call is therefore not part of the result.
    """
    if num_iter is None:
        # range(None) would raise TypeError; fall back to what the config implies.
        configured_size = config["filter_size"]
        num_iter = len(configured_size) if isinstance(configured_size, list) else 1

    save_images = []
    save_depths = []
    vis_depth = depth.copy()
    for i in range(num_iter):
        if isinstance(config["filter_size"], list):
            window_size = config["filter_size"][i]
        else:
            window_size = config["filter_size"]

        # The same array object is appended and then edited below, so the
        # stored image carries the black discontinuity markings.
        vis_image = image.copy()
        save_images.append(vis_image)
        save_depths.append(vis_depth)

        u_over, b_over, l_over, r_over = vis_depth_discontinuity(vis_depth, config, mask=mask)
        for over in (u_over, b_over, l_over, r_over):
            vis_image[over > 0] = np.array([0, 0, 0])

        discontinuity_map = (u_over + b_over + l_over + r_over).clip(0.0, 1.0)
        # Uses the original input depth, not the progressively filtered one.
        discontinuity_map[depth == 0] = 1
        if mask is not None:
            discontinuity_map[mask == 0] = 0

        vis_depth = bilateral_filter(
            vis_depth, config, discontinuity_map=discontinuity_map, HR=HR, mask=mask, window_size=window_size
        )

    return save_images, save_depths
|
|
|
|
|
def vis_depth_discontinuity(depth, config, vis_diff=False, label=False, mask=None):
    """
    Flag interior pixels that differ from their vertical/horizontal neighbours.

    Four maps are produced, in (up, bottom, left, right) order. Each is
    computed on the interior of the input (one-pixel border removed) and
    zero-padded back to the input's shape.

    Modes:
        - ``label == False``: neighbour differences of ``1. / depth`` are
          taken and a pixel is flagged when the absolute difference exceeds
          ``config['depth_threshold']``.
        - otherwise: ``depth`` is used as given, neighbouring values are
          *multiplied*, and a pixel is flagged when the absolute product is
          greater than 0.

    config:
        - depth_threshold: only read when ``label == False``.

    Parameters:
        mask: optional array indexed like ``depth``; each neighbour term is
            multiplied by the product of the two corresponding mask values.
        vis_diff: when true, also return the padded (masked) neighbour terms.

    Returns:
        ``[u_over, b_over, l_over, r_over]`` as float32 arrays, or, when
        ``vis_diff`` is true, the tuple
        ``([u_over, b_over, l_over, r_over], [u_diff, b_diff, l_diff, r_diff])``.
    """
    # Interior crop applied to each neighbour-pair array, in
    # (up, bottom, left, right) order.
    crops = (np.s_[:-1, 1:-1], np.s_[1:, 1:-1], np.s_[1:-1, :-1], np.s_[1:-1, 1:])

    def neighbour_pairs(arr):
        # (first, second) operands for up, bottom, left, right.
        return (
            (arr[1:, :], arr[:-1, :]),
            (arr[:-1, :], arr[1:, :]),
            (arr[:, 1:], arr[:, :-1]),
            (arr[:, :-1], arr[:, 1:]),
        )

    # Equality with False (not truthiness) is what selects the mode.
    by_disparity = label == False  # noqa: E712
    if by_disparity:
        disparity = 1. / depth
        diffs = [
            (first - second)[crop]
            for (first, second), crop in zip(neighbour_pairs(disparity), crops)
        ]
    else:
        diffs = [
            (first * second)[crop]
            for (first, second), crop in zip(neighbour_pairs(depth), crops)
        ]

    if mask is not None:
        diffs = [
            diff * (first * second)[crop]
            for diff, (first, second), crop in zip(diffs, neighbour_pairs(mask), crops)
        ]

    threshold = config['depth_threshold'] if by_disparity else 0
    overs = [(np.abs(diff) > threshold).astype(np.float32) for diff in diffs]

    # Restore the input's shape: the cropped border becomes zeros.
    overs = [np.pad(over, 1, mode='constant') for over in overs]
    diffs = [np.pad(diff, 1, mode='constant') for diff in diffs]

    if vis_diff:
        return overs, diffs
    return overs
|
|
|
def _weighted_median(depth_patch, coef, center):
    """Return the weighted median of ``depth_patch`` under weights ``coef``.

    Falls back to the patch's centre value when the weights sum to zero,
    which is also what protects the normalisation below from dividing by 0.
    """
    weight_total = coef.sum()
    if weight_total == 0:
        return depth_patch[center, center]
    flat_depth = depth_patch.ravel()
    depth_order = flat_depth.argsort()
    cum_coef = np.cumsum((coef / weight_total).ravel()[depth_order])
    # np.digitize gives the first sorted position whose cumulative weight
    # exceeds 0.5, i.e. the weighted median.
    return flat_depth[depth_order][np.digitize(0.5, cum_coef)]


def bilateral_filter(depth, config, discontinuity_map=None, HR=False, mask=None, window_size=False):
    """
    Weighted-median filter of a 2-D depth map over a sliding square window.

    Before filtering, the outermost one-pixel ring of ``depth`` (and of
    ``discontinuity_map``) is replaced by its nearest interior value, and the
    arrays are edge-padded by ``window_size // 2``.

    Modes:
        - ``discontinuity_map`` given: only pixels whose window contains at
          least one non-zero discontinuity value are rewritten. The weights
          are ``1 - discontinuity_map`` inside the window, additionally
          multiplied by the zero-padded ``mask`` when one is supplied.
          Pixels where ``mask`` is 0 are skipped.
        - ``discontinuity_map`` is ``None``: every pixel is rewritten using
          weights ``exp(-d^2 / (2 sigma_s^2)) * exp(-dz^2 / (2 sigma_r^2))``
          where ``d`` is the offset from the window centre and ``dz`` the
          depth difference to the centre pixel. ``mask`` is not consulted.

    In both modes, a pixel whose weights sum to zero keeps its own value.

    config:
        - sigma_s: spatial Gaussian sigma (always read).
        - sigma_r: range Gaussian sigma (always read).
        - filter_size: window size used when ``window_size`` is falsy.

    Parameters:
        window_size: side length of the square window. The padding and
            patch arithmetic only line up with the image for odd sizes.
        HR: accepted for call compatibility; not used here.

    Returns:
        A new array shaped like ``depth`` with the filtered values.
    """
    sigma_s = config['sigma_s']
    sigma_r = config['sigma_r']
    if not window_size:
        window_size = config['filter_size']
    midpt = window_size // 2
    window = [window_size, window_size]

    # Overwrite the one-pixel border with its nearest interior neighbour,
    # then pad so that every pixel has a full window around it.
    depth = np.pad(depth[1:-1, 1:-1], ((1, 1), (1, 1)), 'edge')
    pad_depth = np.pad(depth, (midpt, midpt), 'edge')

    output = depth.copy()
    pad_depth_patches = rolling_window(pad_depth, window, [1, 1])
    pH, pW = pad_depth_patches.shape[:2]

    if discontinuity_map is not None:
        discontinuity_map = np.pad(discontinuity_map[1:-1, 1:-1], ((1, 1), (1, 1)), 'edge')
        pad_discontinuity_map = np.pad(discontinuity_map, (midpt, midpt), 'edge')
        pad_discontinuity_hole = 1 - pad_discontinuity_map
        pad_discontinuity_patches = rolling_window(pad_discontinuity_map, window, [1, 1])
        pad_discontinuity_hole_patches = rolling_window(pad_discontinuity_hole, window, [1, 1])
        if mask is not None:
            pad_mask = np.pad(mask, (midpt, midpt), 'constant')
            pad_mask_patches = rolling_window(pad_mask, window, [1, 1])

        for pi in range(pH):
            for pj in range(pW):
                if mask is not None and mask[pi, pj] == 0:
                    continue
                # Windows without any discontinuity are left untouched.
                if not pad_discontinuity_patches[pi, pj].any():
                    continue
                coef = pad_discontinuity_hole_patches[pi, pj].astype(np.float32)
                if mask is not None:
                    coef = coef * pad_mask_patches[pi, pj]
                output[pi, pj] = _weighted_median(pad_depth_patches[pi, pj], coef, midpt)
        return output

    # The spatial kernel is needed only in this mode; it was previously built
    # only in the other one, so this path raised NameError.
    ax = np.arange(-midpt, midpt + 1.)
    xx, yy = np.meshgrid(ax, ax)
    spatial_term = np.exp(-(xx**2 + yy**2) / (2. * sigma_s**2))

    for pi in range(pH):
        for pj in range(pW):
            depth_patch = pad_depth_patches[pi, pj]
            patch_midpt = depth_patch[midpt, midpt]
            range_term = np.exp(-(depth_patch - patch_midpt)**2 / (2. * sigma_r**2))
            output[pi, pj] = _weighted_median(depth_patch, spatial_term * range_term, midpt)

    return output
|
|
|
def rolling_window(a, window, strides):
    """
    Return a strided view of ``a`` containing every sliding window.

    Parameters:
        a: N-dimensional numpy array.
        window: window length along each axis (N entries).
        strides: step, in elements, between consecutive windows along each
            axis (N entries).

    Returns:
        A view of shape ``[(a.shape[i] - window[i]) // strides[i] + 1, ...]
        + window`` created with ``np.lib.stride_tricks.as_strided``. The
        view shares (overlapping) memory with ``a``; no data is copied.

    Raises:
        AssertionError: if ``a``, ``window`` and ``strides`` do not have the
            same number of dimensions.
    """
    assert len(a.shape) == len(window) == len(strides), "'a', 'window', 'strides' dimension mismatch"
    window_counts = [(dim - w) // s + 1 for dim, w, s in zip(a.shape, window, strides)]
    # Step between windows = element step * the array's own byte stride on
    # that axis. Deriving it from a.strides (rather than from the shape and
    # itemsize) keeps the view correct for transposed or sliced input.
    window_steps = [byte_stride * s for byte_stride, s in zip(a.strides, strides)]

    return np.lib.stride_tricks.as_strided(
        a,
        shape=window_counts + list(window),
        strides=window_steps + list(a.strides),
    )
|
|