Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
import os | |
import os.path | |
import math | |
from PIL import Image, ImageDraw | |
import random | |
import numpy as np | |
import torch | |
import torchvision | |
import torch.utils.data as data | |
from maskrcnn_benchmark.structures.bounding_box import BoxList | |
from maskrcnn_benchmark.structures.segmentation_mask import SegmentationMask | |
from maskrcnn_benchmark.structures.keypoint import PersonKeypoints | |
from maskrcnn_benchmark.config import cfg | |
import pdb | |
def _count_visible_keypoints(anno): | |
return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno) | |
def _has_only_empty_bbox(anno): | |
return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno) | |
def has_valid_annotation(anno):
    """Decide whether an image's COCO annotation list is usable for training."""
    # an empty list means no annotation at all
    if not anno:
        return False
    # all boxes with (close to) zero area count as no annotation
    if _has_only_empty_bbox(anno):
        return False
    # non-keypoint tasks: any non-degenerate box is sufficient
    if "keypoints" not in anno[0]:
        return True
    # keypoint detection uses a slightly different criterion: the image must
    # contain at least the configured minimum number of visible keypoints
    return _count_visible_keypoints(anno) >= cfg.DATALOADER.MIN_KPS_PER_IMS
def pil_loader(path, retry=5):
    """Load the image at ``path`` as an RGB PIL image, retrying on failure.

    Opens through an explicit file handle to avoid ResourceWarning
    (https://github.com/python-pillow/Pillow/issues/835).

    Args:
        path: filesystem path of the image file.
        retry: maximum number of attempts before giving up.

    Returns:
        A PIL ``Image`` in RGB mode, or ``None`` if every attempt failed
        (the original implementation also fell through to ``None`` silently;
        that contract is preserved, but now made explicit).
    """
    for _ in range(retry):
        try:
            with open(path, 'rb') as f:
                img = Image.open(f)
                return img.convert('RGB')
        except Exception:
            # was a bare ``except:`` — narrowed so KeyboardInterrupt and
            # SystemExit propagate instead of being swallowed by the retry loop
            continue
    return None
def rgb2id(color):
    """Encode an RGB value (triple or HxWx3 array) as a base-256 panoptic id.

    The id is ``R + 256*G + 256**2*B``. For an ``(H, W, 3)`` ndarray an
    ``(H, W)`` array of ids is returned; otherwise a single ``int``.
    """
    if isinstance(color, np.ndarray) and color.ndim == 3:
        # widen uint8 before multiplying so the arithmetic cannot overflow
        channels = color.astype(np.int32) if color.dtype == np.uint8 else color
        r, g, b = channels[:, :, 0], channels[:, :, 1], channels[:, :, 2]
        return r + 256 * g + (256 ** 2) * b
    r, g, b = color[0], color[1], color[2]
    return int(r + 256 * g + (256 ** 2) * b)
class CocoDetection(data.Dataset):
    """Detection dataset backed by a COCO-format annotation file.

    Args:
        root (string): directory containing the downloaded images.
        annFile (string): path to the COCO json annotation file.
        transform (callable, optional): applied to the PIL image,
            e.g. ``transforms.ToTensor``.
        target_transform (callable, optional): applied to the annotation list.
    """

    def __init__(self, root, annFile, transform=None, target_transform=None):
        from pycocotools.coco import COCO
        self.root = root
        self.coco = COCO(annFile)
        self.ids = list(self.coco.imgs.keys())
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index, return_meta=False):
        """Fetch one sample.

        Args:
            index (int): dataset index.
            return_meta (bool): when True, also return the COCO image record.

        Returns:
            ``(image, target)`` — or ``(image, target, meta)`` when
            ``return_meta`` is set — where ``target`` is the list returned
            by ``coco.loadAnns``.
        """
        image_id = self.ids[index]
        if isinstance(image_id, str):
            image_id = [image_id]
        annotations = self.coco.loadAnns(self.coco.getAnnIds(imgIds=image_id))
        meta = self.coco.loadImgs(image_id)[0]
        image = pil_loader(os.path.join(self.root, meta['file_name']))
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            annotations = self.target_transform(annotations)
        if return_meta:
            return image, annotations, meta
        return image, annotations

    def __len__(self):
        """Number of images in the dataset."""
        return len(self.ids)

    def __repr__(self):
        transform_label = '    Transforms (if any): '
        target_label = '    Target Transforms (if any): '
        parts = [
            'Dataset ' + self.__class__.__name__,
            '    Number of datapoints: {}'.format(self.__len__()),
            '    Root Location: {}'.format(self.root),
            transform_label + repr(self.transform).replace('\n', '\n' + ' ' * len(transform_label)),
            target_label + repr(self.target_transform).replace('\n', '\n' + ' ' * len(target_label)),
        ]
        return '\n'.join(parts)
class COCODataset(CocoDetection):
    """COCO detection dataset yielding ``(image, BoxList target, idx)`` samples.

    Extends :class:`CocoDetection` with: removal of images lacking usable
    annotations, optional few-shot subsampling per category, remapping of
    sparse COCO category ids to contiguous labels, and packaging of boxes,
    masks, "cbox" boxes and person keypoints into a ``BoxList`` target.
    """

    def __init__(self, ann_file, root, remove_images_without_annotations, transforms=None, ignore_crowd=True,
                 max_box=-1,
                 few_shot=0, one_hot=False, override_category=None, **kwargs
                 ):
        # Args:
        #   ann_file: COCO-format json annotation path.
        #   root: image directory.
        #   remove_images_without_annotations: drop images whose annotations
        #       are empty/degenerate (see ``has_valid_annotation``).
        #   transforms: joint (image, target) transform used in __getitem__.
        #   ignore_crowd: filter out annotations with iscrowd != 0.
        #   max_box: if > 0, cap the number of boxes kept per image.
        #   few_shot: if non-zero, keep only enough images to give each
        #       category roughly ``few_shot`` examples.
        #   one_hot: stored flag; not consumed inside this class.
        #   override_category: replacement list for the dataset's categories.
        #   kwargs: may carry ``shuffle_seed`` used by few-shot sampling.
        super(COCODataset, self).__init__(root, ann_file)
        # sort indices for reproducible results
        self.ids = sorted(self.ids)
        # filter images without detection annotations
        if remove_images_without_annotations:
            ids = []
            for img_id in self.ids:
                # string ids must be wrapped in a list for getAnnIds
                if isinstance(img_id, str):
                    ann_ids = self.coco.getAnnIds(imgIds=[img_id], iscrowd=None)
                else:
                    ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=None)
                anno = self.coco.loadAnns(ann_ids)
                if has_valid_annotation(anno):
                    ids.append(img_id)
            self.ids = ids
        if few_shot:
            ids = []
            # per-category budget of remaining examples to collect
            # NOTE(review): indexing is ``cats_freq[c-1]`` — assumes category
            # ids are contiguous starting at 1; verify for non-COCO category sets
            cats_freq = [few_shot]*len(self.coco.cats.keys())
            if 'shuffle_seed' in kwargs and kwargs['shuffle_seed'] != 0:
                import random
                # deterministic shuffle so the few-shot subset is reproducible
                random.Random(kwargs['shuffle_seed']).shuffle(self.ids)
                print("Shuffle the dataset with random seed: ", kwargs['shuffle_seed'])
            for img_id in self.ids:
                if isinstance(img_id, str):
                    ann_ids = self.coco.getAnnIds(imgIds=[img_id], iscrowd=None)
                else:
                    ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=None)
                anno = self.coco.loadAnns(ann_ids)
                # set/tuple corresponds to instance/image level counting
                cat = set([ann['category_id'] for ann in anno])
                # keep the image if any of its categories still needs examples
                is_needed = sum([cats_freq[c-1]>0 for c in cat])
                if is_needed:
                    ids.append(img_id)
                    # charge every category present in the kept image
                    for c in cat:
                        cats_freq[c-1] -= 1
            self.ids = ids
        if override_category is not None:
            self.coco.dataset["categories"] = override_category
            print("Override category: ", override_category)
        # sparse COCO category id -> contiguous label (1-based; 0 is background)
        self.json_category_id_to_contiguous_id = {
            v: i + 1 for i, v in enumerate(self.coco.getCatIds())
        }
        # inverse mapping, for converting predictions back to COCO ids
        self.contiguous_category_id_to_json_id = {
            v: k for k, v in self.json_category_id_to_contiguous_id.items()
        }
        # dataset index -> COCO image id
        self.id_to_img_map = {k: v for k, v in enumerate(self.ids)}
        self.transforms = transforms
        self.ignore_crowd = ignore_crowd
        self.max_box = max_box
        self.one_hot = one_hot

    def categories(self, no_background=True):
        """Return a dict mapping contiguous label id -> category name.

        Args:
            no_background: skip the ``__background__`` / id-0 entry.
        """
        categories = self.coco.dataset["categories"]
        label_list = {}
        for index, i in enumerate(categories):
            if not no_background or (i["name"] != "__background__" and i['id'] != 0):
                label_list[self.json_category_id_to_contiguous_id[i["id"]]] = i["name"]
        return label_list

    def __getitem__(self, idx):
        """Return ``(image, target, idx)`` where target is a ``BoxList``."""
        img, anno = super(COCODataset, self).__getitem__(idx)
        # filter crowd annotations
        if self.ignore_crowd:
            anno = [obj for obj in anno if obj["iscrowd"] == 0]
        boxes = [obj["bbox"] for obj in anno]
        boxes = torch.as_tensor(boxes).reshape(-1, 4)  # guard against no boxes
        if self.max_box > 0 and len(boxes) > self.max_box:
            # NOTE(review): randperm(self.max_box) shuffles only the first
            # ``max_box`` boxes instead of sampling from all of them —
            # confirm this truncate-then-shuffle behavior is intended
            rand_idx = torch.randperm(self.max_box)
            boxes = boxes[rand_idx, :]
        else:
            rand_idx = None
        target = BoxList(boxes, img.size, mode="xywh").convert("xyxy")
        classes = [obj["category_id"] for obj in anno]
        classes = [self.json_category_id_to_contiguous_id[c] for c in classes]
        classes = torch.tensor(classes)
        if rand_idx is not None:
            # keep labels aligned with the subsampled/permuted boxes
            classes = classes[rand_idx]
        if cfg.DATASETS.CLASS_AGNOSTIC:
            # collapse all categories into a single foreground class
            classes = torch.ones_like(classes)
        target.add_field("labels", classes)
        if anno and "segmentation" in anno[0]:
            masks = [obj["segmentation"] for obj in anno]
            masks = SegmentationMask(masks, img.size, mode='poly')
            target.add_field("masks", masks)
        if anno and "cbox" in anno[0]:
            cboxes = [obj["cbox"] for obj in anno]
            cboxes = torch.as_tensor(cboxes).reshape(-1, 4)  # guard against no boxes
            cboxes = BoxList(cboxes, img.size, mode="xywh").convert("xyxy")
            target.add_field("cbox", cboxes)
        if anno and "keypoints" in anno[0]:
            keypoints = []
            gt_keypoint = self.coco.cats[1]['keypoints']  # <TODO> a better way to get keypoint description
            use_keypoint = cfg.MODEL.ROI_KEYPOINT_HEAD.KEYPOINT_NAME
            for obj in anno:
                if len(use_keypoint) > 0:
                    # keep only the configured subset of keypoints, in the
                    # order listed by the config (each keypoint is an x,y,v triple)
                    kps = []
                    for name in use_keypoint:
                        kp_idx = slice(3 * gt_keypoint.index(name), 3 * gt_keypoint.index(name) + 3)
                        kps += obj["keypoints"][kp_idx]
                    keypoints.append(kps)
                else:
                    keypoints.append(obj["keypoints"])
            keypoints = PersonKeypoints(keypoints, img.size)
            target.add_field("keypoints", keypoints)
        target = target.clip_to_image(remove_empty=True)
        if self.transforms is not None:
            img, target = self.transforms(img, target)
        if cfg.DATASETS.SAMPLE_RATIO != 0.0:
            ratio = cfg.DATASETS.SAMPLE_RATIO
            # positive ratio: sample that fraction of targets;
            # negative ratio: sample a fixed count of ceil(-ratio)
            num_sample_target = math.ceil(len(target) * ratio) if ratio > 0 else math.ceil(-ratio)
            sample_idx = torch.randperm(len(target))[:num_sample_target]
            target = target[sample_idx]
        return img, target, idx

    def get_img_info(self, index):
        """Return the raw COCO image record (file_name, width, height, ...)."""
        img_id = self.id_to_img_map[index]
        img_data = self.coco.imgs[img_id]
        return img_data