Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
import torch | |
import pycocotools.mask as mask_utils | |
# transpose | |
FLIP_LEFT_RIGHT = 0 | |
FLIP_TOP_BOTTOM = 1 | |
class Mask(object): | |
""" | |
This class is unfinished and not meant for use yet | |
It is supposed to contain the mask for an object as | |
a 2d tensor | |
""" | |
def __init__(self, masks, size, mode): | |
self.masks = masks | |
self.size = size | |
self.mode = mode | |
def transpose(self, method): | |
if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM): | |
raise NotImplementedError( | |
"Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented" | |
) | |
width, height = self.size | |
if method == FLIP_LEFT_RIGHT: | |
dim = width | |
idx = 2 | |
elif method == FLIP_TOP_BOTTOM: | |
dim = height | |
idx = 1 | |
flip_idx = list(range(dim)[::-1]) | |
flipped_masks = self.masks.index_select(dim, flip_idx) | |
return Mask(flipped_masks, self.size, self.mode) | |
def crop(self, box): | |
w, h = box[2] - box[0], box[3] - box[1] | |
cropped_masks = self.masks[:, box[1] : box[3], box[0] : box[2]] | |
return Mask(cropped_masks, size=(w, h), mode=self.mode) | |
def resize(self, size, *args, **kwargs): | |
pass | |
class Polygons(object): | |
""" | |
This class holds a set of polygons that represents a single instance | |
of an object mask. The object can be represented as a set of | |
polygons | |
""" | |
def __init__(self, polygons, size, mode): | |
# assert isinstance(polygons, list), '{}'.format(polygons) | |
if isinstance(polygons, list): | |
polygons = [torch.as_tensor(p, dtype=torch.float32) for p in polygons] | |
elif isinstance(polygons, Polygons): | |
polygons = polygons.polygons | |
self.polygons = polygons | |
self.size = size | |
self.mode = mode | |
def transpose(self, method): | |
if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM): | |
raise NotImplementedError( | |
"Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented" | |
) | |
flipped_polygons = [] | |
width, height = self.size | |
if method == FLIP_LEFT_RIGHT: | |
dim = width | |
idx = 0 | |
elif method == FLIP_TOP_BOTTOM: | |
dim = height | |
idx = 1 | |
for poly in self.polygons: | |
p = poly.clone() | |
TO_REMOVE = 1 | |
p[idx::2] = dim - poly[idx::2] - TO_REMOVE | |
flipped_polygons.append(p) | |
return Polygons(flipped_polygons, size=self.size, mode=self.mode) | |
def crop(self, box): | |
w, h = box[2] - box[0], box[3] - box[1] | |
# TODO chck if necessary | |
w = max(w, 1) | |
h = max(h, 1) | |
cropped_polygons = [] | |
for poly in self.polygons: | |
p = poly.clone() | |
p[0::2] = p[0::2] - box[0] # .clamp(min=0, max=w) | |
p[1::2] = p[1::2] - box[1] # .clamp(min=0, max=h) | |
cropped_polygons.append(p) | |
return Polygons(cropped_polygons, size=(w, h), mode=self.mode) | |
def resize(self, size, *args, **kwargs): | |
ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(size, self.size)) | |
if ratios[0] == ratios[1]: | |
ratio = ratios[0] | |
scaled_polys = [p * ratio for p in self.polygons] | |
return Polygons(scaled_polys, size, mode=self.mode) | |
ratio_w, ratio_h = ratios | |
scaled_polygons = [] | |
for poly in self.polygons: | |
p = poly.clone() | |
p[0::2] *= ratio_w | |
p[1::2] *= ratio_h | |
scaled_polygons.append(p) | |
return Polygons(scaled_polygons, size=size, mode=self.mode) | |
def convert(self, mode): | |
width, height = self.size | |
if mode == "mask": | |
rles = mask_utils.frPyObjects( | |
[p.detach().numpy() for p in self.polygons], height, width | |
) | |
rle = mask_utils.merge(rles) | |
mask = mask_utils.decode(rle) | |
mask = torch.from_numpy(mask) | |
# TODO add squeeze? | |
return mask | |
def __repr__(self): | |
s = self.__class__.__name__ + "(" | |
s += "num_polygons={}, ".format(len(self.polygons)) | |
s += "image_width={}, ".format(self.size[0]) | |
s += "image_height={}, ".format(self.size[1]) | |
s += "mode={})".format(self.mode) | |
return s | |
class SegmentationMask(object): | |
""" | |
This class stores the segmentations for all objects in the image | |
""" | |
def __init__(self, polygons, size, mode=None): | |
""" | |
Arguments: | |
polygons: a list of list of lists of numbers. The first | |
level of the list correspond to individual instances, | |
the second level to all the polygons that compose the | |
object, and the third level to the polygon coordinates. | |
""" | |
assert isinstance(polygons, list) | |
self.polygons = [Polygons(p, size, mode) for p in polygons] | |
self.size = size | |
self.mode = mode | |
def transpose(self, method): | |
if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM): | |
raise NotImplementedError( | |
"Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented" | |
) | |
flipped = [] | |
for polygon in self.polygons: | |
flipped.append(polygon.transpose(method)) | |
return SegmentationMask(flipped, size=self.size, mode=self.mode) | |
def crop(self, box): | |
w, h = box[2] - box[0], box[3] - box[1] | |
cropped = [] | |
for polygon in self.polygons: | |
cropped.append(polygon.crop(box)) | |
return SegmentationMask(cropped, size=(w, h), mode=self.mode) | |
def resize(self, size, *args, **kwargs): | |
scaled = [] | |
for polygon in self.polygons: | |
scaled.append(polygon.resize(size, *args, **kwargs)) | |
return SegmentationMask(scaled, size=size, mode=self.mode) | |
def to(self, *args, **kwargs): | |
return self | |
def __getitem__(self, item): | |
if isinstance(item, (int, slice)): | |
selected_polygons = [self.polygons[item]] | |
else: | |
# advanced indexing on a single dimension | |
selected_polygons = [] | |
if isinstance(item, torch.Tensor) and item.dtype == torch.bool: | |
item = item.nonzero() | |
item = item.squeeze(1) if item.numel() > 0 else item | |
item = item.tolist() | |
for i in item: | |
selected_polygons.append(self.polygons[i]) | |
return SegmentationMask(selected_polygons, size=self.size, mode=self.mode) | |
def __iter__(self): | |
return iter(self.polygons) | |
def __repr__(self): | |
s = self.__class__.__name__ + "(" | |
s += "num_instances={}, ".format(len(self.polygons)) | |
s += "image_width={}, ".format(self.size[0]) | |
s += "image_height={})".format(self.size[1]) | |
return s | |