Spaces:
Runtime error
Runtime error
# Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import json | |
import os | |
import time | |
from collections import defaultdict | |
import pycocotools.mask as mask_utils | |
import torchvision | |
from PIL import Image | |
# from .coco import ConvertCocoPolysToMask, make_coco_transforms | |
from .modulated_coco import ConvertCocoPolysToMask | |
def _isArrayLike(obj): | |
return hasattr(obj, "__iter__") and hasattr(obj, "__len__") | |
class LVIS:
    """Lightweight reader for LVIS-format annotation files.

    Loads the annotation JSON and builds lookup tables:

    * ``anns``: annotation id -> annotation dict
    * ``imgs``: image id -> image dict
    * ``cats``: category id -> category dict
    * ``img_ann_map``: image id -> list of annotation dicts
    * ``cat_img_map``: category id -> list of image ids, one entry per
      annotation, so an image id can repeat
    """

    def __init__(self, annotation_path=None):
        """Class for reading and visualizing annotations.

        Args:
            annotation_path (str): location of annotation file. When None, an
                empty object is created; assign ``dataset`` and call
                ``_create_index()`` to populate the lookup tables.
        """
        self.anns = {}
        self.cats = {}
        self.imgs = {}
        self.img_ann_map = defaultdict(list)
        self.cat_img_map = defaultdict(list)
        self.dataset = {}

        if annotation_path is not None:
            print("Loading annotations.")
            tic = time.time()
            self.dataset = self._load_json(annotation_path)
            print("Done (t={:0.2f}s)".format(time.time() - tic))
            assert isinstance(self.dataset, dict), "Annotation file format {} not supported.".format(
                type(self.dataset)
            )
            self._create_index()

    def _load_json(self, path):
        """Parse and return the JSON document stored at *path*.

        The file is decoded as UTF-8 explicitly, so loading does not depend on
        the platform's default locale encoding.
        """
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    def _create_index(self):
        """(Re)build every lookup table from ``self.dataset``.

        ``self.dataset`` must provide "annotations", "images" and "categories"
        lists. Every entry needs an "id" key, and annotations also need
        "image_id" and "category_id".
        """
        print("Creating index.")
        self.img_ann_map = defaultdict(list)
        self.cat_img_map = defaultdict(list)
        self.anns = {}
        self.cats = {}
        self.imgs = {}

        # One pass over the annotations fills all three per-annotation tables.
        for ann in self.dataset["annotations"]:
            self.img_ann_map[ann["image_id"]].append(ann)
            self.cat_img_map[ann["category_id"]].append(ann["image_id"])
            self.anns[ann["id"]] = ann

        for img in self.dataset["images"]:
            self.imgs[img["id"]] = img

        for cat in self.dataset["categories"]:
            self.cats[cat["id"]] = cat

        print("Index created.")

    def get_ann_ids(self, img_ids=None, cat_ids=None, area_rng=None):
        """Get ann ids that satisfy given filter conditions.

        Each filter is applied only when given; omitted filters do not
        restrict the result.

        Args:
            img_ids (int or int array): get anns for given imgs
            cat_ids (int or int array): get anns for given cats
            area_rng (float array): keep anns whose area lies strictly inside
                (area_rng[0], area_rng[1]), e.g. [0, inf]

        Returns:
            ids (int array): integer array of ann ids
        """
        if img_ids is not None:
            img_ids = img_ids if _isArrayLike(img_ids) else [img_ids]
        if cat_ids is not None:
            cat_ids = cat_ids if _isArrayLike(cat_ids) else [cat_ids]

        if img_ids is not None:
            anns = [ann for img_id in img_ids for ann in self.img_ann_map[img_id]]
        else:
            anns = self.dataset["annotations"]

        # return early if no more filtering required
        if cat_ids is None and area_rng is None:
            return [ann["id"] for ann in anns]

        # Only build a category filter when one was requested. Calling
        # set(None) here used to crash when filtering by area alone.
        allowed_cats = None if cat_ids is None else set(cat_ids)
        if area_rng is None:
            area_rng = [0, float("inf")]
        lo, hi = area_rng[0], area_rng[1]

        return [
            ann["id"]
            for ann in anns
            if (allowed_cats is None or ann["category_id"] in allowed_cats) and lo < ann["area"] < hi
        ]

    def get_cat_ids(self):
        """Get all category ids.

        Returns:
            ids (int array): integer array of category ids
        """
        return list(self.cats.keys())

    def get_img_ids(self):
        """Get all img ids.

        Returns:
            ids (int array): integer array of image ids
        """
        return list(self.imgs.keys())

    def _load_helper(self, _dict, ids):
        """Look up *ids* in *_dict*.

        None returns all values, an array-like returns one value per id, and a
        scalar returns a one-element list. An unknown id raises KeyError.
        """
        if ids is None:
            return list(_dict.values())
        elif _isArrayLike(ids):
            return [_dict[id] for id in ids]
        else:
            return [_dict[ids]]

    def load_anns(self, ids=None):
        """Load anns with the specified ids. If ids=None load all anns.

        Args:
            ids (int array): integer array of annotation ids

        Returns:
            anns (dict array) : loaded annotation objects
        """
        return self._load_helper(self.anns, ids)

    def load_cats(self, ids=None):
        """Load categories with the specified ids. If ids=None load all
        categories.

        Args:
            ids (int array): integer array of category ids

        Returns:
            cats (dict array) : loaded category dicts
        """
        return self._load_helper(self.cats, ids)

    def load_imgs(self, ids=None):
        """Load images with the specified ids. If ids=None load all images.

        Args:
            ids (int array): integer array of image ids

        Returns:
            imgs (dict array) : loaded image dicts
        """
        return self._load_helper(self.imgs, ids)

    def download(self, save_dir, img_ids=None):
        """Download images from the URL stored in each image's "coco_url".

        Files that already exist in *save_dir* are skipped.

        Args:
            save_dir (str): dir to save downloaded images (created if missing)
            img_ids (int array): img ids of images to download; None = all
        """
        from urllib.request import urlretrieve

        imgs = self.load_imgs(img_ids)
        os.makedirs(save_dir, exist_ok=True)
        for img in imgs:
            # NOTE(review): image records may lack "file_name" depending on the
            # LVIS release; fall back to the last path component of coco_url,
            # the same field LvisDetectionBase uses to locate images.
            name = img.get("file_name") or img["coco_url"].split("/")[-1]
            file_name = os.path.join(save_dir, name)
            if not os.path.exists(file_name):
                urlretrieve(img["coco_url"], file_name)

    def ann_to_rle(self, ann):
        """Convert annotation which can be polygons, uncompressed RLE to RLE.

        Args:
            ann (dict) : annotation object; its image must be in ``self.imgs``
                (height/width are read from there)

        Returns:
            ann (rle)
        """
        img_data = self.imgs[ann["image_id"]]
        h, w = img_data["height"], img_data["width"]
        segm = ann["segmentation"]
        if isinstance(segm, list):
            # polygon -- a single object might consist of multiple parts
            # we merge all parts into one mask rle code
            rles = mask_utils.frPyObjects(segm, h, w)
            rle = mask_utils.merge(rles)
        elif isinstance(segm["counts"], list):
            # uncompressed RLE
            rle = mask_utils.frPyObjects(segm, h, w)
        else:
            # already-compressed RLE: returned unchanged
            rle = segm
        return rle

    def ann_to_mask(self, ann):
        """Convert annotation which can be polygons, uncompressed RLE, or RLE
        to binary mask.

        Args:
            ann (dict) : annotation object

        Returns:
            binary mask (numpy 2D array)
        """
        rle = self.ann_to_rle(ann)
        return mask_utils.decode(rle)
class LvisDetectionBase(torchvision.datasets.VisionDataset):
    """LVIS detection dataset yielding ``(RGB PIL image, annotation list)`` pairs.

    Samples are ordered by ascending image id. An image file is looked up
    under ``root`` using the last two "/"-separated components of the image's
    ``coco_url``.

    Args:
        root (str): image root directory.
        annFile (str): path to the LVIS annotation JSON, parsed by ``LVIS``.
        transform, target_transform, transforms: forwarded to
            ``torchvision.datasets.VisionDataset``. When ``transforms`` ends up
            non-None it is called as ``transforms(img, target)``.
    """

    def __init__(self, root, annFile, transform=None, target_transform=None, transforms=None):
        super().__init__(
            root,
            transforms=transforms,
            transform=transform,
            target_transform=target_transform,
        )
        self.lvis = LVIS(annFile)
        self.ids = sorted(self.lvis.imgs)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target). target is the list of annotation
            dicts returned by ``LVIS.load_anns`` for this image.
        """
        api = self.lvis
        image_id = self.ids[index]
        annotations = api.load_anns(api.get_ann_ids(img_ids=image_id))

        url_tail = api.load_imgs(image_id)[0]["coco_url"].split("/")[-2:]
        relative_path = "/".join(url_tail)
        image = Image.open(os.path.join(self.root, relative_path)).convert("RGB")

        if self.transforms is not None:
            image, annotations = self.transforms(image, annotations)
        return image, annotations

    def __len__(self):
        """Number of images in the annotation file."""
        return len(self.ids)
class LvisDetection(LvisDetectionBase):
    """LVIS dataset whose samples are ``(image, target, idx)`` triples.

    The raw annotation list from ``LvisDetectionBase`` is wrapped as
    ``{"image_id": ..., "annotations": ...}`` and passed through
    ``ConvertCocoPolysToMask``. The optional ``transforms`` callable is then
    applied to the image only; the target is not transformed.

    Args:
        img_folder (str): image root directory.
        ann_file (str): path to the LVIS annotation JSON.
        transforms (callable or None): image-only transform.
        return_masks (bool): forwarded to ``ConvertCocoPolysToMask``.
        **kwargs: accepted and ignored.
    """

    def __init__(self, img_folder, ann_file, transforms, return_masks=False, **kwargs):
        super().__init__(img_folder, ann_file)
        self.ann_file = ann_file
        self._transforms = transforms
        self.prepare = ConvertCocoPolysToMask(return_masks)

    def __getitem__(self, idx):
        """Return ``(image, prepared_target, idx)`` for sample *idx*."""
        image, annotations = super().__getitem__(idx)
        raw_target = {"image_id": self.ids[idx], "annotations": annotations}
        image, target = self.prepare(image, raw_target)
        if self._transforms is not None:
            image = self._transforms(image)
        return image, target, idx

    def get_raw_image(self, idx):
        """Return the image for sample *idx* without prepare/transform steps."""
        image, _ = super().__getitem__(idx)
        return image

    def categories(self):
        """Map each category id to its name, with keys in ascending id order."""
        names_by_id = {cat["id"]: cat["name"] for cat in self.lvis.dataset["categories"]}
        return {cat_id: names_by_id[cat_id] for cat_id in sorted(names_by_id)}