from functools import lru_cache

import cv2
import numpy as np
import tops
import torch
from skimage.morphology import disk
from torchvision.transforms.functional import InterpolationMode, resize


@lru_cache(maxsize=200)
def get_kernel(n: int):
    """Returns a cached boolean disk kernel with radius n, placed on the GPU."""
    kernel = disk(n, dtype=bool)
    return tops.to_cuda(torch.from_numpy(kernel).bool())
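

# Usage sketch: assumes a CUDA device is available, since get_kernel moves the
# kernel to the GPU via tops.to_cuda. A radius-2 disk is a 5x5 boolean kernel.
def _demo_get_kernel():
    kernel = get_kernel(2)
    assert kernel.shape == (5, 5)
    assert kernel.dtype == torch.bool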


def transform_embedding(E: torch.Tensor, S: torch.Tensor, exp_bbox, E_bbox, target_imshape):
    """
    Transforms the detected embedding/mask directly to the target image shape.
    """
    C, HE, WE = E.shape
    # E_bbox must lie fully inside the expanded box exp_bbox.
    assert E_bbox[0] >= exp_bbox[0], (E_bbox, exp_bbox)
    assert E_bbox[2] >= exp_bbox[0]
    assert E_bbox[1] >= exp_bbox[1]
    assert E_bbox[3] >= exp_bbox[1]
    assert E_bbox[2] <= exp_bbox[2]
    assert E_bbox[3] <= exp_bbox[3]
    # Map E_bbox from exp_bbox coordinates to target image coordinates.
    x0 = int(np.round((E_bbox[0] - exp_bbox[0]) / (exp_bbox[2] - exp_bbox[0]) * target_imshape[1]))
    x1 = int(np.round((E_bbox[2] - exp_bbox[0]) / (exp_bbox[2] - exp_bbox[0]) * target_imshape[1]))
    y0 = int(np.round((E_bbox[1] - exp_bbox[1]) / (exp_bbox[3] - exp_bbox[1]) * target_imshape[0]))
    y1 = int(np.round((E_bbox[3] - exp_bbox[1]) / (exp_bbox[3] - exp_bbox[1]) * target_imshape[0]))
    new_E = torch.zeros((C, *target_imshape), device=E.device, dtype=torch.float32)
    new_S = torch.zeros(target_imshape, device=S.device, dtype=torch.bool)
    E = resize(E, (y1 - y0, x1 - x0), antialias=True, interpolation=InterpolationMode.BILINEAR)
    new_E[:, y0:y1, x0:x1] = E
    S = resize(S[None].float(), (y1 - y0, x1 - x0), antialias=True, interpolation=InterpolationMode.BILINEAR)[0] > 0
    new_S[y0:y1, x0:x1] = S
    return new_E, new_S
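

# Usage sketch with dummy CPU tensors; the embedding dimension, boxes, and
# target shape below are illustrative assumptions, not detector output.
def _demo_transform_embedding():
    E = torch.rand(16, 32, 32)                # detected embedding crop, C=16
    S = torch.ones(32, 32, dtype=torch.bool)  # detected mask crop
    exp_bbox = [0, 0, 128, 128]               # expanded box in image coordinates
    E_bbox = [32, 32, 96, 96]                 # detection box inside exp_bbox
    new_E, new_S = transform_embedding(E, S, exp_bbox, E_bbox, target_imshape=(64, 64))
    # E_bbox spans the central half of exp_bbox along each axis, so the
    # mask lands in new_S[16:48, 16:48] of the 64x64 target canvas.
    assert new_E.shape == (16, 64, 64)
    assert new_S[16:48, 16:48].all() and not new_S[:16].any()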


def pairwise_mask_iou(mask1: torch.Tensor, mask2: torch.Tensor):
    """
    mask1, mask2: boolean masks of shape [N, H, W].
    Returns a [N1, N2] matrix of pairwise IoU values.
    """
    assert len(mask1.shape) == 3
    assert len(mask2.shape) == 3
    assert mask1.device == mask2.device, (mask1.device, mask2.device)
    assert mask1.dtype == mask2.dtype
    assert mask1.dtype == torch.bool
    assert mask1.shape[1:] == mask2.shape[1:]
    N1, H1, W1 = mask1.shape
    N2, H2, W2 = mask2.shape
    iou = torch.zeros((N1, N2), dtype=torch.float32)
    for i in range(N1):
        cur = mask1[i:i+1]
        inter = torch.logical_and(cur, mask2).flatten(start_dim=1).float().sum(dim=1).cpu()
        union = torch.logical_or(cur, mask2).flatten(start_dim=1).float().sum(dim=1).cpu()
        iou[i] = inter / union
    return iou
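

# Usage sketch: two disjoint half-plane masks compared against themselves;
# the diagonal is 1 (self-overlap) and the off-diagonal 0 (no overlap).
def _demo_pairwise_mask_iou():
    masks = torch.zeros(2, 8, 8, dtype=torch.bool)
    masks[0, :4] = True  # top half
    masks[1, 4:] = True  # bottom half
    iou = pairwise_mask_iou(masks, masks)
    assert torch.allclose(iou, torch.eye(2))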


def find_best_matches(mask1: torch.Tensor, mask2: torch.Tensor, iou_threshold: float):
    N1 = mask1.shape[0]
    N2 = mask2.shape[0]
    ious = pairwise_mask_iou(mask1, mask2).cpu().numpy()
    # All (i, j) index pairs in row-major order, filtered by the threshold.
    indices = np.array([idx for idx, iou in np.ndenumerate(ious)])
    ious = ious.flatten()
    mask = ious >= iou_threshold
    ious = ious[mask]
    indices = indices[mask]

    # Do not sort by IoU, to keep the ordering of the Mask R-CNN / CSE sorting.
    taken1 = np.zeros((N1), dtype=bool)
    taken2 = np.zeros((N2), dtype=bool)
    matches = []
    for i, j in indices:
        # Greedily match each instance at most once.
        if taken1[i] or taken2[j]:
            continue
        matches.append((i, j))
        taken1[i] = True
        taken2[j] = True
    return matches
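

# Usage sketch with dummy masks: the overlapping pair is matched, the
# disjoint mask is left unmatched. The 0.5 threshold is an arbitrary choice.
def _demo_find_best_matches():
    m1 = torch.zeros(2, 8, 8, dtype=torch.bool)
    m2 = torch.zeros(1, 8, 8, dtype=torch.bool)
    m1[0, :4] = True
    m1[1, 4:] = True
    m2[0, :4] = True  # identical to m1[0], disjoint from m1[1]
    matches = find_best_matches(m1, m2, iou_threshold=0.5)
    assert len(matches) == 1
    assert tuple(map(int, matches[0])) == (0, 0)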


def combine_cse_maskrcnn_dets(segmentation: torch.Tensor, cse_dets: dict, iou_threshold: float):
    assert 0 < iou_threshold <= 1
    matches = find_best_matches(segmentation, cse_dets["im_segmentation"], iou_threshold)
    H, W = segmentation.shape[1:]
    # Merge each matched Mask R-CNN / CSE pair into a single mask.
    new_seg = torch.zeros((len(matches), H, W), dtype=torch.bool, device=segmentation.device)
    cse_im_seg = cse_dets["im_segmentation"]
    for idx, (i, j) in enumerate(matches):
        new_seg[idx] = torch.logical_or(segmentation[i], cse_im_seg[j])
    matched_cse = [j for (i, j) in matches]
    cse_dets = dict(
        instance_segmentation=cse_dets["instance_segmentation"][matched_cse],
        instance_embedding=cse_dets["instance_embedding"][matched_cse],
        bbox_XYXY=cse_dets["bbox_XYXY"][matched_cse],
        scores=cse_dets["scores"][matched_cse],
    )
    return new_seg, cse_dets, np.array(matches).reshape(-1, 2)
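

# Usage sketch with a dummy single-instance cse_dets dict; the field shapes
# below are illustrative assumptions, not the exact detector output shapes.
def _demo_combine_cse_maskrcnn_dets():
    seg = torch.zeros(1, 8, 8, dtype=torch.bool)
    seg[0, :4] = True
    cse_dets = dict(
        im_segmentation=seg.clone(),
        instance_segmentation=torch.zeros(1, 4, 4, dtype=torch.bool),
        instance_embedding=torch.zeros(1, 16, 4, 4),
        bbox_XYXY=torch.tensor([[0, 0, 8, 4]]),
        scores=torch.ones(1),
    )
    new_seg, matched_dets, matches = combine_cse_maskrcnn_dets(seg, cse_dets, iou_threshold=0.5)
    assert new_seg.shape == (1, 8, 8)
    assert matches.shape == (1, 2)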


def initialize_cse_boxes(segmentation: torch.Tensor, cse_boxes: torch.Tensor):
    """
    Combines mask-derived boxes with cse_boxes, which can lie outside of the segmentation.
    """
    boxes = masks_to_boxes(segmentation)
    assert boxes.shape == cse_boxes.shape, (boxes.shape, cse_boxes.shape)
    combined = torch.stack((boxes, cse_boxes), dim=-1)
    # Union of the two boxes: min over (x0, y0), max over (x1, y1).
    boxes = torch.cat((
        combined[:, :2].min(dim=2).values,
        combined[:, 2:].max(dim=2).values,
    ), dim=1)
    return boxes
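

# Usage sketch: the returned box is the union of the mask-derived box
# (2, 2, 6, 6) and a CSE box that sticks out horizontally.
def _demo_initialize_cse_boxes():
    seg = torch.zeros(1, 8, 8, dtype=torch.bool)
    seg[0, 2:6, 2:6] = True
    cse_boxes = torch.tensor([[1, 3, 7, 5]])
    assert initialize_cse_boxes(seg, cse_boxes).tolist() == [[1, 2, 7, 6]]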


def cut_pad_resize(x: torch.Tensor, bbox, target_shape, fdf_resize=False):
    """
    Crops or pads x to fit the bbox and resizes to the target shape.
    """
    C, H, W = x.shape
    x0, y0, x1, y1 = bbox

    if y0 > 0 and x0 > 0 and x1 <= W and y1 <= H:
        # The box lies fully inside the image: a simple crop suffices.
        new_x = x[:, y0:y1, x0:x1]
    else:
        # The box extends outside the image: zero-pad the out-of-bounds region.
        new_x = torch.zeros((C, y1 - y0, x1 - x0), dtype=x.dtype, device=x.device)
        y0_t = max(0, -y0)
        y1_t = min(y1 - y0, (y1 - y0) - (y1 - H))
        x0_t = max(0, -x0)
        x1_t = min(x1 - x0, (x1 - x0) - (x1 - W))
        x0 = max(0, x0)
        y0 = max(0, y0)
        x1 = min(x1, W)
        y1 = min(y1, H)
        new_x[:, y0_t:y1_t, x0_t:x1_t] = x[:, y0:y1, x0:x1]
    # Nearest upsampling often generates sharper synthesized identities.
    interp = InterpolationMode.BICUBIC
    if (y1 - y0) < target_shape[0] and (x1 - x0) < target_shape[1]:
        interp = InterpolationMode.NEAREST
    antialias = interp == InterpolationMode.BICUBIC
    if x1 - x0 == target_shape[1] and y1 - y0 == target_shape[0]:
        return new_x
    if x.dtype == torch.bool:
        new_x = resize(new_x.float(), target_shape, interpolation=InterpolationMode.NEAREST) > 0.5
    elif x.dtype == torch.float32:
        new_x = resize(new_x, target_shape, interpolation=interp, antialias=antialias)
    elif x.dtype == torch.uint8:
        if fdf_resize:
            # The FDF dataset is created with cv2 INTER_AREA.
            # Incorrect resizing generates noticeably poorer inpaintings.
            upsampling = ((y1 - y0) * (x1 - x0)) < (target_shape[0] * target_shape[1])
            if upsampling:
                new_x = resize(new_x.float(), target_shape, interpolation=InterpolationMode.BICUBIC,
                               antialias=True).round().clamp(0, 255).byte()
            else:
                device = new_x.device
                new_x = new_x.permute(1, 2, 0).cpu().numpy()
                new_x = cv2.resize(new_x, target_shape[::-1], interpolation=cv2.INTER_AREA)
                new_x = torch.from_numpy(np.rollaxis(new_x, 2)).to(device)
        else:
            new_x = resize(new_x.float(), target_shape, interpolation=interp,
                           antialias=antialias).round().clamp(0, 255).byte()
    else:
        raise ValueError(f"Unsupported dtype: {x.dtype}")
    return new_x
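

# Usage sketch: a box hanging over the top-left image corner is zero-padded
# before resizing. Shapes and the bbox below are illustrative assumptions.
def _demo_cut_pad_resize():
    x = torch.rand(3, 16, 16)
    out = cut_pad_resize(x, bbox=(-4, -4, 12, 12), target_shape=(32, 32))
    assert out.shape == (3, 32, 32)
    # The 16x16 padded crop is upsampled to 32x32 with nearest interpolation
    # (both sides grow), so the 4-pixel zero margin becomes an 8-pixel border.
    assert not out[:, :8, :8].any()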


def masks_to_boxes(segmentation: torch.Tensor):
    assert len(segmentation.shape) == 3
    x = segmentation.any(dim=1).byte()  # Compress rows: per-column occupancy
    x0 = x.argmax(dim=1)
    x1 = segmentation.shape[2] - x.flip(dims=(1,)).argmax(dim=1)
    y = segmentation.any(dim=2).byte()  # Compress columns: per-row occupancy
    y0 = y.argmax(dim=1)
    y1 = segmentation.shape[1] - y.flip(dims=(1,)).argmax(dim=1)
    return torch.stack([x0, y0, x1, y1], dim=1)
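

# Usage sketch: a single 4x4 square of foreground pixels yields the
# expected XYXY box.
def _demo_masks_to_boxes():
    seg = torch.zeros(1, 8, 8, dtype=torch.bool)
    seg[0, 1:5, 2:6] = True
    assert masks_to_boxes(seg).tolist() == [[2, 1, 6, 5]]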