Spaces:
Runtime error
Runtime error
File size: 3,782 Bytes
4ea50ff |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
import torch
import torchvision
from maskrcnn_benchmark.structures.bounding_box import BoxList
from maskrcnn_benchmark.structures.segmentation_mask import SegmentationMask
from maskrcnn_benchmark.structures.ke import textKES
from maskrcnn_benchmark.structures.mty import MTY
DEBUG = 0
class WordDataset(torchvision.datasets.coco.CocoDetection):
    """COCO-style dataset for word/text detection.

    Wraps torchvision's ``CocoDetection`` and converts each sample's
    annotations into a ``BoxList`` target carrying class labels,
    segmentation masks, and — when the annotations provide them —
    keypoint-derived "kes" and match-type "mty" fields.
    """

    def __init__(
        self, ann_file, root, remove_images_without_annotations, transforms=None
    ):
        """
        Args:
            ann_file: path to the COCO-format annotation JSON.
            root: image directory root.
            remove_images_without_annotations: drop images with no annotations.
            transforms: optional callable applied as (img, target) pair.
        """
        # NOTE: CocoDetection takes (root, ann_file) — the argument order of
        # this wrapper is intentionally reversed; keep it for existing callers.
        super(WordDataset, self).__init__(root, ann_file)
        # sort indices for reproducible results
        self.ids = sorted(self.ids)
        # filter images without detection annotations
        if remove_images_without_annotations:
            self.ids = [
                img_id
                for img_id in self.ids
                if len(self.coco.getAnnIds(imgIds=img_id, iscrowd=None)) > 0
            ]
        # COCO category ids may be sparse; remap them to contiguous ids
        # starting at 1 (0 is implicitly reserved for background).
        self.json_category_id_to_contiguous_id = {
            v: i + 1 for i, v in enumerate(self.coco.getCatIds())
        }
        self.contiguous_category_id_to_json_id = {
            v: k for k, v in self.json_category_id_to_contiguous_id.items()
        }
        self.id_to_img_map = {k: v for k, v in enumerate(self.ids)}
        self.transforms = transforms

    def kes_encode(self, kes):
        """Average each point's x/y with the per-instance minima at
        positions 0 (min x) and 1 (min y), returning new lists.

        Each element of *kes* is a flat list of [x, y, v] triples, so its
        length must be a multiple of 3.  Unlike the original implementation,
        the caller's lists are copied, not mutated in place (the annotation
        lists come from pycocotools' cache and are reused across epochs).
        """
        encoded = []
        for ke in kes:
            assert len(ke) % 3 == 0
            ke = list(ke)  # copy: never mutate the caller's annotation data
            mnx = ke[0]
            mny = ke[1]
            npts = len(ke) // 3 - 2
            for index in range(npts):
                ke[3 + index * 3] = (ke[3 + index * 3] + mnx) / 2
                ke[4 + index * 3] = (ke[4 + index * 3] + mny) / 2
            encoded.append(ke)
        return encoded

    def kes_gen(self, kes):
        """Select the 12 values forming a "kes" target from each flat
        keypoint list: [mnx, 4 x-coords, cx, mny, 4 y-coords, cy].

        The hard-coded indices (3/6/9/12, 16/19/22/25 and 27/28 for cx/cy)
        are fixed by the annotation format — ten [x, y, v] triples per
        instance; lengths must therefore be multiples of 3.
        """
        kes_gen_out = []
        for ke in kes:
            assert len(ke) % 3 == 0
            mnx = ke[0]
            mny = ke[1]
            cx = ke[27]
            cy = ke[28]
            ot = [mnx, ke[3], ke[6], ke[9], ke[12], cx,
                  mny, ke[16], ke[19], ke[22], ke[25], cy]
            kes_gen_out.append(ot)
        return kes_gen_out

    def __getitem__(self, idx):
        """Return ``(img, target, idx)`` where *target* is a BoxList in
        "xyxy" mode with "labels", "masks" and optionally "kes"/"mty" fields.
        """
        img, anno = super(WordDataset, self).__getitem__(idx)

        # filter crowd annotations
        # TODO might be better to add an extra field
        anno = [obj for obj in anno if obj["iscrowd"] == 0]

        boxes = [obj["bbox"] for obj in anno]
        if DEBUG:
            # guard the [0] access so debug printing survives empty images
            print('len(boxes)', len(boxes), boxes[0] if boxes else None)
        boxes = torch.as_tensor(boxes).reshape(-1, 4)  # guard against no boxes
        target = BoxList(boxes, img.size, mode="xywh").convert("xyxy")

        classes = [obj["category_id"] for obj in anno]
        if DEBUG:
            print('len(classes)', len(classes), classes[0] if classes else None)
        classes = [self.json_category_id_to_contiguous_id[c] for c in classes]
        classes = torch.tensor(classes)
        target.add_field("labels", classes)

        masks = [obj["segmentation"] for obj in anno]
        if DEBUG:
            print('len(masks)', len(masks), masks[0] if masks else None)
        masks = SegmentationMask(masks, img.size)
        target.add_field("masks", masks)

        # keypoints are optional in the annotation schema
        if anno and 'keypoints' in anno[0]:
            kes = [obj["keypoints"] for obj in anno]
            kes = self.kes_gen(kes)
            if DEBUG:
                print('len(kes)', len(kes), kes[0] if kes else None)
            kes = textKES(kes, img.size)
            target.add_field("kes", kes)

        # match_type is optional in the annotation schema
        if anno and 'match_type' in anno[0]:
            mty = [obj["match_type"] for obj in anno]
            mty = MTY(mty, img.size)
            target.add_field("mty", mty)

        target = target.clip_to_image(remove_empty=True)

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target, idx

    def get_img_info(self, index):
        """Return the raw COCO image record (dict with width, height, file
        name, …) for the dataset position *index*."""
        img_id = self.id_to_img_map[index]
        img_data = self.coco.imgs[img_id]
        return img_data
|