# ext/davis2017/utils.py (OMG_Seg)
import os
import errno
import numpy as np
from PIL import Image
import warnings
from ext.davis2017.davis import DAVIS

def _pascal_color_map(N=256, normalized=False):
    """
    Python implementation of the color map function for the PASCAL VOC data set.
    Official Matlab version can be found in the PASCAL VOC devkit
    http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html#devkit
    """

    def bitget(byteval, idx):
        return (byteval & (1 << idx)) != 0

    dtype = 'float32' if normalized else 'uint8'
    cmap = np.zeros((N, 3), dtype=dtype)
    for i in range(N):
        r = g = b = 0
        c = i
        for j in range(8):
            r = r | (bitget(c, 0) << 7 - j)
            g = g | (bitget(c, 1) << 7 - j)
            b = b | (bitget(c, 2) << 7 - j)
            c = c >> 3
        cmap[i] = np.array([r, g, b])

    cmap = cmap / 255 if normalized else cmap
    return cmap
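
# Sanity check: the palette follows the standard VOC bit-interleaving scheme,
# so the first entries are
#   _pascal_color_map()[:3] -> [[0, 0, 0], [128, 0, 0], [0, 128, 0]]
# i.e. index 0 (background) is black and object id 1 maps to dark red.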

def overlay_semantic_mask(im, ann, alpha=0.5, colors=None, contour_thickness=None):
    # Note: np.int was removed from recent NumPy releases; use the builtin int.
    im, ann = np.asarray(im, dtype=np.uint8), np.asarray(ann, dtype=int)
    if im.shape[:-1] != ann.shape:
        raise ValueError('First two dimensions of `im` and `ann` must match')
    if im.shape[-1] != 3:
        raise ValueError('`im` must have three channels in its last dimension')

    # Explicit None check: `colors or ...` fails if `colors` is a NumPy array.
    colors = _pascal_color_map() if colors is None else colors
    colors = np.asarray(colors, dtype=np.uint8)

    mask = colors[ann]
    fg = im * alpha + (1 - alpha) * mask

    img = im.copy()
    img[ann > 0] = fg[ann > 0]

    if contour_thickness:  # pragma: no cover
        import cv2
        for obj_id in np.unique(ann[ann > 0]):
            contours = cv2.findContours((ann == obj_id).astype(
                np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:]
            cv2.drawContours(img, contours[0], -1, colors[obj_id].tolist(),
                             contour_thickness)
    return img
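
# Hypothetical usage sketch (file names below are placeholders, not part of
# the repository):
#
#   from PIL import Image
#   frame = np.array(Image.open('frame_00000.jpg'))   # HxWx3 uint8 image
#   masks = np.array(Image.open('frame_00000.png'))   # HxW map of object ids
#   vis = overlay_semantic_mask(frame, masks, alpha=0.5, contour_thickness=2)
#   Image.fromarray(vis).save('overlay_00000.png')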

def generate_obj_proposals(davis_root, subset, num_proposals, save_path):
    dataset = DAVIS(davis_root, subset=subset, codalab=True)
    for seq in dataset.get_sequences():
        save_dir = os.path.join(save_path, seq)
        if os.path.exists(save_dir):
            continue
        all_gt_masks, all_masks_id = dataset.get_all_masks(seq, True)
        img_size = all_gt_masks.shape[2:]
        num_rows = int(np.ceil(np.sqrt(num_proposals)))
        proposals = np.zeros((num_proposals, len(all_masks_id), *img_size))
        height_slices = np.floor(np.arange(0, img_size[0] + 1, img_size[0] / num_rows)).astype(np.uint).tolist()
        width_slices = np.floor(np.arange(0, img_size[1] + 1, img_size[1] / num_rows)).astype(np.uint).tolist()
        ii = 0
        prev_h, prev_w = 0, 0
        for h in height_slices[1:]:
            for w in width_slices[1:]:
                proposals[ii, :, prev_h:h, prev_w:w] = 1
                prev_w = w
                ii += 1
                if ii == num_proposals:
                    break
            prev_h, prev_w = h, 0
            if ii == num_proposals:
                break
        os.makedirs(save_dir, exist_ok=True)
        for i, mask_id in enumerate(all_masks_id):
            mask = np.sum(proposals[:, i, ...] * np.arange(1, proposals.shape[0] + 1)[:, None, None], axis=0)
            save_mask(mask, os.path.join(save_dir, f'{mask_id}.png'))
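
# Hypothetical call (paths and arguments are placeholders):
#
#   generate_obj_proposals('/path/to/DAVIS', 'test-dev', num_proposals=20,
#                          save_path='/path/to/output/proposals')
#
# Each frame is tiled into a roughly sqrt(num_proposals) x sqrt(num_proposals)
# grid of rectangular masks, written out as one indexed PNG per frame id.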

def generate_random_permutation_gt_obj_proposals(davis_root, subset, save_path):
    dataset = DAVIS(davis_root, subset=subset, codalab=True)
    for seq in dataset.get_sequences():
        gt_masks, all_masks_id = dataset.get_all_masks(seq, True)
        obj_swap = np.random.permutation(np.arange(gt_masks.shape[0]))
        gt_masks = gt_masks[obj_swap, ...]
        save_dir = os.path.join(save_path, seq)
        os.makedirs(save_dir, exist_ok=True)
        for i, mask_id in enumerate(all_masks_id):
            mask = np.sum(gt_masks[:, i, ...] * np.arange(1, gt_masks.shape[0] + 1)[:, None, None], axis=0)
            save_mask(mask, os.path.join(save_dir, f'{mask_id}.png'))
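
# Hypothetical call, mirroring generate_obj_proposals above:
#
#   generate_random_permutation_gt_obj_proposals('/path/to/DAVIS', 'val',
#                                                '/path/to/output/permuted_gt')
#
# The ground-truth masks are re-saved with their object ids randomly permuted
# within each sequence.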

def color_map(N=256, normalized=False):
    def bitget(byteval, idx):
        return ((byteval & (1 << idx)) != 0)

    dtype = 'float32' if normalized else 'uint8'
    cmap = np.zeros((N, 3), dtype=dtype)
    for i in range(N):
        r = g = b = 0
        c = i
        for j in range(8):
            r = r | (bitget(c, 0) << 7 - j)
            g = g | (bitget(c, 1) << 7 - j)
            b = b | (bitget(c, 2) << 7 - j)
            c = c >> 3
        cmap[i] = np.array([r, g, b])

    cmap = cmap / 255 if normalized else cmap
    return cmap

def save_mask(mask, img_path):
    if np.max(mask) > 255:
        raise ValueError('Maximum id pixel value is 255')
    mask_img = Image.fromarray(mask.astype(np.uint8))
    mask_img.putpalette(color_map().flatten().tolist())
    mask_img.save(img_path)
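
# Hypothetical usage (array contents and path are placeholders):
#
#   mask = np.zeros((480, 854), dtype=np.uint8)
#   mask[100:200, 300:500] = 1                    # pixels belonging to object 1
#   save_mask(mask, '/tmp/00000.png')             # indexed PNG with VOC palette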

def db_statistics(per_frame_values):
    """ Compute mean, recall and decay from per-frame evaluation.
    Arguments:
        per_frame_values (ndarray): per-frame evaluation

    Returns:
        M, O, D (float, float, float):
            evaluation statistics: mean, recall, decay.
    """

    # strip off nan values
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        M = np.nanmean(per_frame_values)
        O = np.nanmean(per_frame_values > 0.5)

    N_bins = 4
    ids = np.round(np.linspace(1, len(per_frame_values), N_bins + 1) + 1e-10) - 1
    ids = ids.astype(np.uint8)

    D_bins = [per_frame_values[ids[i]:ids[i + 1] + 1] for i in range(0, 4)]

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        D = np.nanmean(D_bins[0]) - np.nanmean(D_bins[3])

    return M, O, D
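
# Example sketch (the values are made up for illustration):
#
#   j_per_frame = np.array([0.8, 0.7, np.nan, 0.6, 0.4])
#   J_mean, J_recall, J_decay = db_statistics(j_per_frame)
#
# M is the NaN-aware mean, O the fraction of frames scoring above 0.5, and D
# the drop between the mean of the first and last quarter of the sequence.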
def list_files(dir, extension=".png"):
return [os.path.splitext(file_)[0] for file_ in os.listdir(dir) if file_.endswith(extension)]

def force_symlink(file1, file2):
    try:
        os.symlink(file1, file2)
    except OSError as e:
        if e.errno == errno.EEXIST:
            os.remove(file2)
            os.symlink(file1, file2)
        else:
            # Do not silently swallow unrelated OSErrors (e.g. permission issues).
            raise
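
# Hypothetical usage (paths are placeholders):
#
#   force_symlink('/data/DAVIS/JPEGImages/480p', '/tmp/davis_frames')
#
# Unlike a bare os.symlink, an existing link at the destination is replaced
# instead of raising FileExistsError.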