'''
Crop utilities for torch tensors.
Given an image and a bbox (center, bbox_size), return the cropped image and
the transform `tform` (used to transform the keypoints accordingly).
Only square crops are supported.
'''
import torch
from kornia.geometry.transform.imgwarp import (get_perspective_transform,
                                               warp_affine)
def points2bbox(points, points_scale=None):
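    ''' Compute a square bbox (center, size) enclosing the given points.
    Args:
        points: [bz, n_points, >=2]; if points_scale is given, x/y are assumed
            to be in [-1, 1] and are first mapped to pixel coordinates.
        points_scale: optional (H, W) of the source image; must be square.
    Returns:
        center: [bz, 2]
        size: [bz, 1]
    '''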
if points_scale:
assert points_scale[0] == points_scale[1]
points = points.clone()
points[:, :, :2] = (points[:, :, :2] * 0.5 + 0.5) * points_scale[0]
min_coords, _ = torch.min(points, dim=1)
xmin, ymin = min_coords[:, 0], min_coords[:, 1]
max_coords, _ = torch.max(points, dim=1)
xmax, ymax = max_coords[:, 0], max_coords[:, 1]
center = torch.stack([xmax + xmin, ymax + ymin], dim=-1) * 0.5
width = (xmax - xmin)
height = (ymax - ymin)
# Convert the bounding box to a square box
size = torch.max(width, height).unsqueeze(-1)
return center, size
def augment_bbox(center, bbox_size, scale=[1.0, 1.0], trans_scale=0.):
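    ''' Randomly translate and scale a bbox for data augmentation.
    Args:
        center: [bz, 2]; bbox_size: [bz, 1]
        scale: [min, max] range for the uniform random scale factor.
        trans_scale: maximum translation, as a fraction of bbox_size.
    Returns:
        augmented center: [bz, 2] and size: [bz, 1]
    '''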
batch_size = center.shape[0]
trans_scale = (torch.rand([batch_size, 2], device=center.device) * 2. -
1.) * trans_scale
center = center + trans_scale * bbox_size # 0.5
scale = torch.rand([batch_size, 1], device=center.device) * \
(scale[1] - scale[0]) + scale[0]
size = bbox_size * scale
return center, size
def crop_tensor(image,
center,
bbox_size,
crop_size,
interpolation='bilinear',
align_corners=False):
    ''' Crop a batch of images.
    Args:
        image (torch.Tensor): the reference tensor of shape BxCxHxW.
        center: [bz, 2], bbox center in pixel coordinates.
        bbox_size: [bz, 1], side length of the square bbox in pixels.
        crop_size (int): side length of the square output crop.
        interpolation (str): Interpolation flag. Default: 'bilinear'.
        align_corners (bool): mode for grid generation. Default: False. See
        https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.interpolate for details
    Returns:
        cropped_image: [bz, C, crop_size, crop_size]
        tform: [bz, 3, 3], transpose of the crop transform; right-multiply
            homogeneous points by it (see transform_points).
    '''
dtype = image.dtype
device = image.device
batch_size = image.shape[0]
# points: top-left, top-right, bottom-right, bottom-left
src_pts = torch.zeros([4, 2], dtype=dtype,
device=device).unsqueeze(0).expand(
batch_size, -1, -1).contiguous()
src_pts[:, 0, :] = center - bbox_size * 0.5 # / (self.crop_size - 1)
src_pts[:, 1, 0] = center[:, 0] + bbox_size[:, 0] * 0.5
src_pts[:, 1, 1] = center[:, 1] - bbox_size[:, 0] * 0.5
src_pts[:, 2, :] = center + bbox_size * 0.5
src_pts[:, 3, 0] = center[:, 0] - bbox_size[:, 0] * 0.5
src_pts[:, 3, 1] = center[:, 1] + bbox_size[:, 0] * 0.5
DST_PTS = torch.tensor([[
[0, 0],
[crop_size - 1, 0],
[crop_size - 1, crop_size - 1],
[0, crop_size - 1],
]],
dtype=dtype,
device=device).expand(batch_size, -1, -1)
# estimate transformation between points
dst_trans_src = get_perspective_transform(src_pts, DST_PTS)
# simulate broadcasting
# dst_trans_src = dst_trans_src.expand(batch_size, -1, -1)
# warp images
cropped_image = warp_affine(image,
dst_trans_src[:, :2, :],
(crop_size, crop_size),
mode=interpolation,
align_corners=align_corners)
tform = torch.transpose(dst_trans_src, 2, 1)
# tform = torch.inverse(dst_trans_src)
return cropped_image, tform
class Cropper(object):
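    ''' Crop a batch of images around keypoints and keep the transform that
    maps points from the original image into the crop.
    '''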
def __init__(self, crop_size, scale=[1, 1], trans_scale=0.):
self.crop_size = crop_size
self.scale = scale
self.trans_scale = trans_scale
def crop(self, image, points, points_scale=None):
# points to bbox
center, bbox_size = points2bbox(points.clone(), points_scale)
        # augment bbox. TODO: add rotation?
center, bbox_size = augment_bbox(center,
bbox_size,
scale=self.scale,
trans_scale=self.trans_scale)
# crop
cropped_image, tform = crop_tensor(image, center, bbox_size,
self.crop_size)
return cropped_image, tform
def transform_points(self,
points,
tform,
points_scale=None,
normalize=True):
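        ''' Map points from the original image frame into the crop frame using
        `tform` (as returned by crop). If normalize is True, x/y are rescaled
        from [0, crop_size] to [-1, 1].
        '''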
points_2d = points[:, :, :2]
        # input points must be in the original (pixel) range
if points_scale:
assert points_scale[0] == points_scale[1]
points_2d = (points_2d * 0.5 + 0.5) * points_scale[0]
batch_size, n_points, _ = points.shape
trans_points_2d = torch.bmm(
torch.cat([
points_2d,
torch.ones([batch_size, n_points, 1],
device=points.device,
dtype=points.dtype)
],
dim=-1), tform)
trans_points = torch.cat([trans_points_2d[:, :, :2], points[:, :, 2:]],
dim=-1)
if normalize:
trans_points[:, :, :2] = trans_points[:, :, :2] / \
self.crop_size*2 - 1
return trans_points
def transform_points(points, tform, points_scale=None):
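    ''' Functional variant of Cropper.transform_points: map points into the
    crop frame with `tform`, without normalization.
    '''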
points_2d = points[:, :, :2]
    # input points must be in the original (pixel) range
if points_scale:
assert points_scale[0] == points_scale[1]
points_2d = (points_2d * 0.5 + 0.5) * points_scale[0]
batch_size, n_points, _ = points.shape
trans_points_2d = torch.bmm(
torch.cat([
points_2d,
torch.ones([batch_size, n_points, 1],
device=points.device,
dtype=points.dtype)
],
dim=-1), tform)
trans_points = torch.cat([trans_points_2d[:, :, :2], points[:, :, 2:]],
dim=-1)
return trans_points
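

# Minimal usage sketch (not part of the original module); assumes torch and
# kornia are installed. The shapes and the crop_size/scale values below are
# arbitrary example choices.
if __name__ == '__main__':
    batch_size, n_points = 2, 68
    image = torch.rand(batch_size, 3, 512, 512)          # BxCxHxW
    points = torch.rand(batch_size, n_points, 2) * 512   # keypoints in pixels
    cropper = Cropper(crop_size=224, scale=[1.2, 1.4], trans_scale=0.1)
    cropped_image, tform = cropper.crop(image, points)
    # map the original keypoints into the crop, normalized to [-1, 1]
    cropped_points = cropper.transform_points(points, tform)
    print(cropped_image.shape, cropped_points.shape)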