File size: 3,782 Bytes
4ea50ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import torch
import torchvision

from maskrcnn_benchmark.structures.bounding_box import BoxList
from maskrcnn_benchmark.structures.segmentation_mask import SegmentationMask

from maskrcnn_benchmark.structures.ke import textKES
from maskrcnn_benchmark.structures.mty import MTY

# Debug switch: when set to a truthy value, WordDataset.__getitem__ prints the
# number of boxes / classes / masks / keypoint entries it loaded for a sample.
DEBUG = 0

class WordDataset(torchvision.datasets.coco.CocoDetection):
    """COCO-format detection dataset yielding ``(image, target, index)``.

    ``target`` is a ``BoxList`` in ``xyxy`` mode carrying the fields
    ``labels`` and ``masks`` and, when the annotations provide them,
    ``kes`` (from ``"keypoints"``) and ``mty`` (from ``"match_type"``).
    """

    def __init__(
        self, ann_file, root, remove_images_without_annotations, transforms=None
    ):
        """Load the annotation index and build the category id mappings.

        Args:
            ann_file: annotation file, forwarded to ``CocoDetection``.
            root: image root directory, forwarded to ``CocoDetection``.
            remove_images_without_annotations: when truthy, image ids for
                which the COCO index returns no annotation ids are dropped.
            transforms: optional callable invoked as
                ``transforms(img, target)`` and expected to return the
                transformed ``(img, target)`` pair.
        """
        super(WordDataset, self).__init__(root, ann_file)
        # sort indices for reproducible results
        self.ids = sorted(self.ids)

        # filter images without detection annotations
        if remove_images_without_annotations:
            self.ids = [
                img_id
                for img_id in self.ids
                if self.coco.getAnnIds(imgIds=img_id, iscrowd=None)
            ]

        # JSON category ids are remapped to contiguous ids starting at 1.
        self.json_category_id_to_contiguous_id = {
            json_id: position + 1
            for position, json_id in enumerate(self.coco.getCatIds())
        }
        self.contiguous_category_id_to_json_id = {
            contiguous: json_id
            for json_id, contiguous in self.json_category_id_to_contiguous_id.items()
        }
        self.id_to_img_map = dict(enumerate(self.ids))
        self.transforms = transforms

    @staticmethod
    def _debug_show(label, items):
        """Print the size and first element of *items* when DEBUG is on.

        An empty *items* prints ``None`` for the first element instead of
        raising ``IndexError``.
        """
        if DEBUG:
            print(label, len(items), items[0] if items else None)

    def kes_encode(self, kes):
        """Average intermediate x/y values with the first triple's x/y.

        Each entry is treated as a flat sequence of triples. For every
        triple except the first and the last, its first two values are
        replaced by their mean with the first triple's first two values.
        (Presumably the triples are ``(x, y, visibility)`` keypoints —
        confirm against the annotation format.)

        The input entries are NOT modified; encoded copies are returned.
        Annotation lists are typically shared with the cached COCO index,
        so editing them in place would compound on repeated calls.

        Args:
            kes: iterable of flat sequences whose length is a multiple of 3.

        Returns:
            A list with one encoded list per input entry.
        """
        encoded = []
        for entry in kes:
            assert len(entry) % 3 == 0
            values = list(entry)  # work on a copy, leave the caller's data alone
            anchor_x = values[0]
            anchor_y = values[1]
            for k in range(len(values) // 3 - 2):
                x_pos = 3 + 3 * k
                y_pos = x_pos + 1
                values[x_pos] = (values[x_pos] + anchor_x) / 2
                values[y_pos] = (values[y_pos] + anchor_y) / 2
            encoded.append(values)
        return encoded

    def kes_gen(self, kes):
        """Pick twelve fixed positions out of every flat keypoint entry.

        The output row is built from indices ``0, 3, 6, 9, 12, 27`` followed
        by indices ``1, 16, 19, 22, 25, 28`` of the entry, so each entry
        needs at least 29 values. (Presumably x-coordinates followed by
        y-coordinates with a centre point last — confirm against the
        annotation format.)

        Args:
            kes: iterable of flat sequences whose length is a multiple of 3.

        Returns:
            A list of 12-element lists, one per input entry.
        """
        first_half = (0, 3, 6, 9, 12, 27)
        second_half = (1, 16, 19, 22, 25, 28)
        rows = []
        for entry in kes:
            # validate the layout before any element is read
            assert len(entry) % 3 == 0
            rows.append(
                [entry[p] for p in first_half] + [entry[p] for p in second_half]
            )
        return rows

    def __getitem__(self, idx):
        """Return ``(img, target, idx)`` for the sample at position *idx*.

        Crowd annotations are discarded, the remaining boxes are converted
        from ``xywh`` to ``xyxy``, the extra fields are attached, the target
        is clipped to the image (dropping empty boxes) and finally the
        optional transforms are applied.
        """
        img, anno = super(WordDataset, self).__getitem__(idx)
        # filter crowd annotations
        # TODO might be better to add an extra field
        anno = [obj for obj in anno if obj["iscrowd"] == 0]

        boxes = [obj["bbox"] for obj in anno]
        self._debug_show('len(boxes)', boxes)
        boxes = torch.as_tensor(boxes).reshape(-1, 4)  # guard against no boxes
        target = BoxList(boxes, img.size, mode="xywh").convert("xyxy")

        classes = [obj["category_id"] for obj in anno]
        self._debug_show('len(classes)', classes)
        classes = [self.json_category_id_to_contiguous_id[c] for c in classes]
        target.add_field("labels", torch.tensor(classes))

        masks = [obj["segmentation"] for obj in anno]
        self._debug_show('len(masks)', masks)
        target.add_field("masks", SegmentationMask(masks, img.size))

        # Only the first annotation is inspected for the optional keys;
        # assumes all annotations of an image share the same keys.
        if anno and 'keypoints' in anno[0]:
            kes = self.kes_gen([obj["keypoints"] for obj in anno])
            self._debug_show('len(kes)', kes)
            target.add_field("kes", textKES(kes, img.size))

        if anno and 'match_type' in anno[0]:
            mty = [obj["match_type"] for obj in anno]
            target.add_field("mty", MTY(mty, img.size))

        target = target.clip_to_image(remove_empty=True)

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target, idx

    def get_img_info(self, index):
        """Return the COCO image record for the dataset position *index*."""
        return self.coco.imgs[self.id_to_img_map[index]]