import cv2
import numpy as np
import torch
import tops
from skimage.morphology import disk
from torchvision.transforms.functional import resize, InterpolationMode
from functools import lru_cache


@lru_cache(maxsize=200)
def get_kernel(n: int):
    kernel = disk(n, dtype=bool)
    return tops.to_cuda(torch.from_numpy(kernel).bool())
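
# Usage sketch (hypothetical radius; assumes a CUDA device is available, since
# tops.to_cuda moves the kernel to the GPU):
#   kernel = get_kernel(8)  # (17, 17) boolean disk tensor, cached across calls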


def transform_embedding(E: torch.Tensor, S: torch.Tensor, exp_bbox, E_bbox, target_imshape):
    """
        Transforms the detected embedding/mask directly to the target image shape
    """

    C, HE, WE = E.shape
    # E_bbox must lie fully inside exp_bbox.
    assert E_bbox[0] >= exp_bbox[0], (E_bbox, exp_bbox)
    assert E_bbox[2] >= exp_bbox[0]
    assert E_bbox[1] >= exp_bbox[1]
    assert E_bbox[3] >= exp_bbox[1]
    assert E_bbox[2] <= exp_bbox[2]
    assert E_bbox[3] <= exp_bbox[3]

    x0 = int(np.round((E_bbox[0] - exp_bbox[0]) / (exp_bbox[2] - exp_bbox[0]) * target_imshape[1]))
    x1 = int(np.round((E_bbox[2] - exp_bbox[0]) / (exp_bbox[2] - exp_bbox[0]) * target_imshape[1]))
    y0 = int(np.round((E_bbox[1] - exp_bbox[1]) / (exp_bbox[3] - exp_bbox[1]) * target_imshape[0]))
    y1 = int(np.round((E_bbox[3] - exp_bbox[1]) / (exp_bbox[3] - exp_bbox[1]) * target_imshape[0]))
    new_E = torch.zeros((C, *target_imshape), device=E.device, dtype=torch.float32)
    new_S = torch.zeros(target_imshape, device=S.device, dtype=torch.bool)

    E = resize(E, (y1-y0, x1-x0), antialias=True, interpolation=InterpolationMode.BILINEAR)
    new_E[:, y0:y1, x0:x1] = E
    S = resize(S[None].float(), (y1-y0, x1-x0), antialias=True, interpolation=InterpolationMode.BILINEAR)[0] > 0
    new_S[y0:y1, x0:x1] = S
    return new_E, new_S
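
# Worked example of the bbox mapping above (hypothetical numbers):
#   exp_bbox = (0, 0, 200, 100), E_bbox = (50, 25, 150, 75), target_imshape = (128, 256)
#   x0 = round(50 / 200 * 256) = 64,  x1 = round(150 / 200 * 256) = 192
#   y0 = round(25 / 100 * 128) = 32,  y1 = round(75 / 100 * 128) = 96
# so the embedding is resized to (64, 128) and pasted at new_E[:, 32:96, 64:192].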


def pairwise_mask_iou(mask1: torch.Tensor, mask2: torch.Tensor):
    """
        mask1, mask2: boolean masks of shape [N, H, W]
        Returns a [N1, N2] matrix of pairwise IoU values.
    """
    assert len(mask1.shape) == 3
    assert len(mask2.shape) == 3
    assert mask1.device == mask2.device, (mask1.device, mask2.device)
    assert mask1.dtype == mask2.dtype
    assert mask1.dtype == torch.bool
    assert mask1.shape[1:] == mask2.shape[1:]
    N1, H1, W1 = mask1.shape
    N2, H2, W2 = mask2.shape
    iou = torch.zeros((N1, N2), dtype=torch.float32)
    for i in range(N1):
        cur = mask1[i:i+1]
        inter = torch.logical_and(cur, mask2).flatten(start_dim=1).float().sum(dim=1).cpu()
        union = torch.logical_or(cur, mask2).flatten(start_dim=1).float().sum(dim=1).cpu()
        iou[i] = inter / union
    return iou
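
# Usage sketch (hypothetical masks):
#   m1 = torch.zeros((1, 4, 4), dtype=torch.bool); m1[0, :2] = True   # rows 0-1
#   m2 = torch.zeros((1, 4, 4), dtype=torch.bool); m2[0, 1:3] = True  # rows 1-2
#   pairwise_mask_iou(m1, m2)  # tensor([[0.3333]]): 4 px intersection / 12 px union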


def find_best_matches(mask1: torch.Tensor, mask2: torch.Tensor, iou_threshold: float):
    N1 = mask1.shape[0]
    N2 = mask2.shape[0]
    ious = pairwise_mask_iou(mask1, mask2).cpu().numpy()
    indices = np.array([idx for idx, iou in np.ndenumerate(ious)])
    ious = ious.flatten()
    mask = ious >= iou_threshold
    ious = ious[mask]
    indices = indices[mask]

    # Do not sort by IoU; keep the original Mask R-CNN / CSE detection ordering.
    taken1 = np.zeros((N1), dtype=bool)
    taken2 = np.zeros((N2), dtype=bool)
    matches = []
    for i, j in indices:
        if taken1[i] or taken2[j]:
            continue
        matches.append((i, j))
        taken1[i] = True
        taken2[j] = True
    return matches
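
# Usage sketch: with m1, m2 from the example above,
#   find_best_matches(m1, m2, iou_threshold=0.3)  # -> [(0, 0)]
# Matching is greedy in detection order (not IoU order), and each mask is used at most once.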


def combine_cse_maskrcnn_dets(segmentation: torch.Tensor, cse_dets: dict, iou_threshold: float):
    assert 0 < iou_threshold <= 1
    matches = find_best_matches(segmentation, cse_dets["im_segmentation"], iou_threshold)
    H, W = segmentation.shape[1:]
    new_seg = torch.zeros((len(matches), H, W), dtype=torch.bool, device=segmentation.device)
    cse_im_seg = cse_dets["im_segmentation"]
    for idx, (i, j) in enumerate(matches):
        new_seg[idx] = torch.logical_or(segmentation[i], cse_im_seg[j])
    matched_cse = [j for (i, j) in matches]
    cse_dets = dict(
        instance_segmentation=cse_dets["instance_segmentation"][matched_cse],
        instance_embedding=cse_dets["instance_embedding"][matched_cse],
        bbox_XYXY=cse_dets["bbox_XYXY"][matched_cse],
        scores=cse_dets["scores"][matched_cse],
    )
    return new_seg, cse_dets, np.array(matches).reshape(-1, 2)
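
# Usage sketch (hypothetical inputs; cse_dets must contain the keys indexed above):
#   new_seg, cse_dets, matches = combine_cse_maskrcnn_dets(maskrcnn_seg, cse_dets, 0.4)
# matches has shape (num_matches, 2); column 0 indexes the Mask R-CNN detections,
# column 1 the CSE detections, and new_seg[k] is the union of matched pair k.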


def initialize_cse_boxes(segmentation: torch.Tensor, cse_boxes: torch.Tensor):
    """
        cse_boxes can be outside of segmentation.
    """
    boxes = masks_to_boxes(segmentation)

    assert boxes.shape == cse_boxes.shape, (boxes.shape, cse_boxes.shape)
    combined = torch.stack((boxes, cse_boxes), dim=-1)
    boxes = torch.cat((
        combined[:, :2].min(dim=2).values,
        combined[:, 2:].max(dim=2).values,
    ), dim=1)
    return boxes
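
# Example (hypothetical boxes): a segmentation box (10, 10, 50, 50) combined with a
# CSE box (5, 12, 55, 48) gives (5, 10, 55, 50): elementwise min for x0/y0, max for x1/y1.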


def cut_pad_resize(x: torch.Tensor, bbox, target_shape, fdf_resize=False):
    """
        Crops or pads x to fit the bbox, then resizes to target_shape.
    """
    C, H, W = x.shape
    x0, y0, x1, y1 = bbox

    if y0 > 0 and x0 > 0 and x1 <= W and y1 <= H:
        new_x = x[:, y0:y1, x0:x1]
    else:
        new_x = torch.zeros((C, y1-y0, x1-x0), dtype=x.dtype, device=x.device)
        y0_t = max(0, -y0)
        y1_t = min(y1-y0, (y1-y0)-(y1-H))
        x0_t = max(0, -x0)
        x1_t = min(x1-x0, (x1-x0)-(x1-W))
        x0 = max(0, x0)
        y0 = max(0, y0)
        x1 = min(x1, W)
        y1 = min(y1, H)
        new_x[:, y0_t:y1_t, x0_t:x1_t] = x[:, y0:y1, x0:x1]
    # Nearest-neighbor upsampling often yields sharper synthesized identities.
    interp = InterpolationMode.BICUBIC
    if (y1-y0) < target_shape[0] and (x1-x0) < target_shape[1]:
        interp = InterpolationMode.NEAREST
    antialias = interp == InterpolationMode.BICUBIC
    if x1 - x0 == target_shape[1] and y1 - y0 == target_shape[0]:
        return new_x
    if x.dtype == torch.bool:
        new_x = resize(new_x.float(), target_shape, interpolation=InterpolationMode.NEAREST) > 0.5
    elif x.dtype == torch.float32:
        new_x = resize(new_x, target_shape, interpolation=interp, antialias=antialias)
    elif x.dtype == torch.uint8:
        if fdf_resize:  # The FDF dataset was created with cv2.INTER_AREA;
            # resizing with a different filter yields noticeably poorer inpaintings.
            upsampling = ((y1-y0) * (x1-x0)) < (target_shape[0] * target_shape[1])
            if upsampling:
                new_x = resize(new_x.float(), target_shape, interpolation=InterpolationMode.BICUBIC,
                               antialias=True).round().clamp(0, 255).byte()
            else:
                device = new_x.device
                new_x = new_x.permute(1, 2, 0).cpu().numpy()
                new_x = cv2.resize(new_x, target_shape[::-1], interpolation=cv2.INTER_AREA)
                new_x = torch.from_numpy(np.rollaxis(new_x, 2)).to(device)
        else:
            new_x = resize(new_x.float(), target_shape, interpolation=interp,
                           antialias=antialias).round().clamp(0, 255).byte()
    else:
        raise ValueError(f"Not supported dtype: {x.dtype}")
    return new_x
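
# Usage sketch (hypothetical tensors): a bbox reaching past the image border is
# zero-padded before resizing.
#   im = torch.zeros((3, 100, 100), dtype=torch.uint8)
#   out = cut_pad_resize(im, (-10, -10, 90, 90), (128, 128))  # -> (3, 128, 128) uint8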


def masks_to_boxes(segmentation: torch.Tensor):
    assert len(segmentation.shape) == 3
    x = segmentation.any(dim=1).byte()  # Collapse rows: per-column occupancy
    x0 = x.argmax(dim=1)  # First occupied column
    x1 = segmentation.shape[2] - x.flip(dims=(1,)).argmax(dim=1)  # One past the last occupied column
    y = segmentation.any(dim=2).byte()  # Collapse columns: per-row occupancy
    y0 = y.argmax(dim=1)  # First occupied row
    y1 = segmentation.shape[1] - y.flip(dims=(1,)).argmax(dim=1)  # One past the last occupied row
    return torch.stack([x0, y0, x1, y1], dim=1)
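
# Example (hypothetical mask): a single mask that is True in rows 2-4 and columns 3-6
# of an 8x8 grid yields the box (3, 2, 7, 5), i.e. XYXY with exclusive x1/y1.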